├── .dockerignore
├── .github
├── CODEOWNERS
├── PULL_REQUEST_TEMPLATE.md
└── workflows
│ ├── AddLabel.yml
│ ├── CPUTests.yml
│ ├── RunTests.yml
│ ├── UploadDockerImages.yml
│ ├── build_and_upload_images.sh
│ ├── build_upload_internal.yml
│ ├── require-checklist.yml
│ ├── run_tests_internal.yml
│ └── utils
│ └── setup_runner.sh
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
├── launch.json
└── settings.json
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── MaxText
├── __init__.py
├── accelerator_to_spec_map.py
├── benchmark_chunked_prefill.py
├── checkpointing.py
├── common_types.py
├── configs
│ ├── README.md
│ ├── a3
│ │ ├── llama_2_7b
│ │ │ ├── 16vm.sh
│ │ │ ├── 1vm.sh
│ │ │ ├── 2vm.sh
│ │ │ ├── 4vm.sh
│ │ │ ├── 8vm.sh
│ │ │ └── README.md
│ │ └── llama_3.1_405b
│ │ │ └── 128vm.sh
│ ├── base.yml
│ ├── dpo.yml
│ ├── experimental
│ │ ├── 1024b.sh
│ │ ├── 128b.sh
│ │ ├── 256b.sh
│ │ ├── 32b.sh
│ │ ├── 512b.sh
│ │ └── 64b.sh
│ ├── gpu_smoke_test.yml
│ ├── inference.yml
│ ├── inference_jetstream.yml
│ ├── models
│ │ ├── deepseek2-16b.yml
│ │ ├── deepseek2-236b.yml
│ │ ├── deepseek3-671b.yml
│ │ ├── gemma-2b.yml
│ │ ├── gemma-7b.yml
│ │ ├── gemma2-27b.yml
│ │ ├── gemma2-2b.yml
│ │ ├── gemma2-9b.yml
│ │ ├── gemma3-12b.yml
│ │ ├── gemma3-27b.yml
│ │ ├── gemma3-4b.yml
│ │ ├── gpt3-175b.yml
│ │ ├── gpt3-22b.yml
│ │ ├── gpt3-52k.yml
│ │ ├── gpt3-6b.yml
│ │ ├── gpu
│ │ │ ├── llama2_70b.yml
│ │ │ ├── llama2_7b.yml
│ │ │ ├── llama3.1_405b.yml
│ │ │ ├── llama3_70b.yml
│ │ │ ├── llama3_8b.yml
│ │ │ ├── mixtral_8x1b.yml
│ │ │ ├── mixtral_8x2b.yml
│ │ │ └── mixtral_8x7b.yml
│ │ ├── llama2-13b.yml
│ │ ├── llama2-70b.yml
│ │ ├── llama2-7b.yml
│ │ ├── llama3-405b.yml
│ │ ├── llama3-70b.yml
│ │ ├── llama3-8b.yml
│ │ ├── llama3.1-405b.yml
│ │ ├── llama3.1-70b.yml
│ │ ├── llama3.1-8b.yml
│ │ ├── llama3.3-70b.yml
│ │ ├── llama4-17b-128e.yml
│ │ ├── llama4-17b-16e.yml
│ │ ├── mistral-7b.yml
│ │ ├── mixtral-8x22b.yml
│ │ └── mixtral-8x7b.yml
│ ├── quantization
│ │ ├── README.md
│ │ ├── dense_llm_subchannel.json
│ │ ├── dense_llm_weight_only_scale.json
│ │ ├── int4_weight_only.json
│ │ └── int8_weight_only.json
│ ├── sft.yml
│ ├── tpu_smoke_test.yml
│ ├── trillium
│ │ ├── gemma2_27b.sh
│ │ ├── gemma2_9b.sh
│ │ ├── gemma3_27b.sh
│ │ ├── gpt3_175b.sh
│ │ ├── llama2_7b_4096.sh
│ │ └── mixtral_8x7b.sh
│ ├── v4
│ │ ├── 22b.sh
│ │ ├── 52b.sh
│ │ └── README.md
│ ├── v5e
│ │ ├── 128b.sh
│ │ ├── 16b.sh
│ │ ├── 32b.sh
│ │ ├── 64b.sh
│ │ ├── README.md
│ │ ├── gpt3_175b.sh
│ │ ├── llama2_13b.sh
│ │ ├── llama2_70b.sh
│ │ ├── llama2_70b_v5e-16.yml
│ │ ├── llama2_7b.sh
│ │ ├── llama3_405b_v5e-64.yml
│ │ └── llama3_70b_v5e-16.yml
│ ├── v5p
│ │ ├── 1024b.sh
│ │ ├── 128b.sh
│ │ ├── 256b.sh
│ │ ├── 32b.sh
│ │ ├── 512b.sh
│ │ ├── 64b.sh
│ │ ├── README.md
│ │ ├── gpt3_175b
│ │ │ ├── gpt3_175b_base.sh
│ │ │ ├── v5p_1024.sh
│ │ │ ├── v5p_12288.sh
│ │ │ ├── v5p_2048.sh
│ │ │ ├── v5p_3072.sh
│ │ │ ├── v5p_4096.sh
│ │ │ └── v5p_8192.sh
│ │ ├── llama2_70b.sh
│ │ └── llama2_7b.sh
│ └── v6e
│ │ └── inference
│ │ └── llama4_maverick_v6e-64.yml
├── convert_deepseek_ckpt.py
├── convert_deepseek_unscanned_ckpt.py
├── convert_gemma2_chkpt.py
├── convert_gemma3_chkpt.py
├── convert_gemma_chkpt.py
├── convert_gpt3_ckpt_from_paxml.py
├── decode.py
├── deepseek_fp8_to_bf16.py
├── elastic_train.py
├── experimental
│ └── rl
│ │ ├── grpo.yml
│ │ ├── grpo_input_pipeline.py
│ │ ├── grpo_trainer.py
│ │ └── grpo_trainer_test.yml
├── gcp_workload_monitor.py
├── generate_distillation_data.py
├── generate_param_only_checkpoint.py
├── globals.py
├── inference
│ ├── __init__.py
│ ├── configs
│ │ └── multi_host
│ │ │ ├── disaggregation
│ │ │ └── llama3_405b_v6e-16-16.yml
│ │ │ └── interleaved
│ │ │ ├── llama2_70b_v5e-16.yml
│ │ │ ├── llama3_405b_v5e-64.yml
│ │ │ └── llama3_70b_v5e-16.yml
│ ├── decode_multi.py
│ ├── gpu
│ │ ├── README.md
│ │ └── microbenchmark_llama2-70b_h100-8.sh
│ ├── jetstream_pathways
│ │ ├── Dockerfile
│ │ ├── README.md
│ │ └── jetstream_pathways_entrypoint.sh
│ ├── kvcache.py
│ ├── maxengine_server
│ │ ├── Dockerfile
│ │ ├── README.md
│ │ └── maxengine_server_entrypoint.sh
│ ├── page_manager.py
│ ├── paged_attention.py
│ └── paged_attention_kernel_v2.py
├── inference_microbenchmark.py
├── inference_microbenchmark_sweep.py
├── inference_mlperf
│ ├── README.md
│ ├── __init__.py
│ ├── evaluate-accuracy-fast.py
│ ├── evaluate-accuracy.py
│ ├── gpu
│ │ └── benchmarks_llama2-70b-h100_8.sh
│ ├── llama_offline_run.sh
│ ├── matmul
│ │ ├── __init__.py
│ │ ├── matmul_dtypes.py
│ │ ├── matmul_sharding.py
│ │ └── timing_util.py
│ ├── mixtral_offline_run.sh
│ ├── offline_inference.py
│ ├── offline_mode.py
│ ├── requirements.txt
│ ├── trillium
│ │ ├── __init__.py
│ │ ├── benchmarks_llama2-70b-trillium_2x4.sh
│ │ ├── microbenchmarks_llama2-70b-trillium_2x4.sh
│ │ └── select_xla_flags.py
│ ├── user.conf
│ ├── user100.conf
│ └── user5000.conf
├── inference_utils.py
├── input_pipeline
│ ├── __init__.py
│ ├── _distillation_data_processing.py
│ ├── _grain_data_processing.py
│ ├── _grain_tokenizer.py
│ ├── _hf_data_processing.py
│ ├── _input_pipeline_utils.py
│ ├── _tfds_data_processing.py
│ ├── _tfds_data_processing_c4_mlperf.py
│ └── input_pipeline_interface.py
├── kernels
│ ├── __init__.py
│ ├── megablox
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── gmm.py
│ │ └── ops.py
│ └── ragged_attention.py
├── layers
│ ├── __init__.py
│ ├── attentions.py
│ ├── deepseek.py
│ ├── embeddings.py
│ ├── gemma.py
│ ├── gemma2.py
│ ├── gemma3.py
│ ├── gpt3.py
│ ├── initializers.py
│ ├── linears.py
│ ├── llama2.py
│ ├── llama4.py
│ ├── mistral.py
│ ├── mixtral.py
│ ├── models.py
│ ├── moe.py
│ ├── multi_token_prediction.py
│ ├── normalizations.py
│ ├── pipeline.py
│ ├── quantizations.py
│ └── simple_layer.py
├── llama4_ckpt_unscanned.py
├── llama_ckpt_conversion_inference_only.py
├── llama_mistral_mixtral_orbax_to_hf.py
├── llama_or_mistral_ckpt.py
├── load_and_quantize_checkpoint.py
├── max_logging.py
├── max_utils.py
├── maxengine.py
├── maxengine_config.py
├── maxengine_server.py
├── maxtext_utils.py
├── metric_logger.py
├── multihost_dataloading.py
├── multimodal_utils.py
├── optimizers.py
├── prefill_packing.py
├── profiler.py
├── pyconfig.py
├── scratch_code
│ ├── __init__.py
│ ├── analyze_sharegpt.py
│ ├── gemma_7b.sh
│ ├── generate_grpo_golden_logits.py
│ ├── generate_hf_golden_logits.py
│ ├── generate_sft_golden_data.py
│ ├── golden_gemma-2b_export.ipynb
│ ├── golden_gemma2-27b_export-flax.ipynb
│ ├── golden_gemma2-2b_export-flax.ipynb
│ ├── golden_gemma2-2b_export.ipynb
│ ├── golden_gemma2-9b_export-flax.ipynb
│ ├── golden_gemma2-9b_export.ipynb
│ ├── golden_llama2-70b_export.py
│ ├── golden_llama2-7b_export.ipynb
│ ├── golden_llama3-8b_export.ipynb
│ ├── golden_llama3_1_export.py
│ ├── golden_llama4_17b_16e_128e_export.ipynb
│ ├── golden_mistral-7b_export.ipynb
│ ├── golden_mixtral-8x22b_export.ipynb
│ ├── golden_mixtral-8x7b_export.ipynb
│ ├── mixtral-numerical-verification.ipynb
│ ├── run_inference_microbenchmark.sh
│ └── setup_transformer.sh
├── sequence_packing.py
├── sft_trainer.py
├── standalone_checkpointer.py
├── standalone_dataloader.py
├── test_assets
│ ├── .gitignore
│ ├── golden_data_deepseek_r1_distill_llama3.1-70b.jsonl
│ ├── golden_data_deepseek_r1_distill_llama3.1_8b.jsonl
│ ├── golden_data_gemma3_vit.jsonl
│ ├── golden_data_grpo_default.jsonl
│ ├── golden_data_sft_default.jsonl
│ └── test_image.jpg
├── tests
│ ├── __init__.py
│ ├── aot_hlo_identical_script.sh
│ ├── aot_hlo_identical_test.py
│ ├── attention_test.py
│ ├── check_llama4_layers.py
│ ├── check_mla_vs_reference.py
│ ├── context_parallelism_test.py
│ ├── decode_tests.py
│ ├── distillation_data_processing_test.py
│ ├── elastic_train_test.py
│ ├── forward_pass_logit_checker.py
│ ├── gpt3_test.py
│ ├── grain_data_processing_test.py
│ ├── grpo_trainer_correctness_test.py
│ ├── hf_checkpoint_conversion_checker.py
│ ├── hf_checkpoint_conversion_test.py
│ ├── hf_data_processing_test.py
│ ├── inference
│ │ ├── __init__.py
│ │ ├── kvcache_test.py
│ │ ├── page_manager_test.py
│ │ ├── test_llama2_7b_bf16.sh
│ │ └── test_llama2_7b_int8.sh
│ ├── integration_tests
│ │ ├── __init__.py
│ │ ├── checkpoint_compatibility_test.py
│ │ ├── checkpointing_test.py
│ │ ├── generate_param_only_checkpoint_test.py
│ │ ├── gradient_accumulation_test.py
│ │ ├── grpo_correctness.py
│ │ ├── inference_microbenchmark_smoke_test.py
│ │ ├── sft_trainer_correctness_test.py
│ │ ├── shmap_collective_matmul_test.py
│ │ ├── standalone_dl_ckpt_test.py
│ │ ├── train_tests.py
│ │ └── vision_encoder_test.py
│ ├── kernels_test.py
│ ├── llama_test.py
│ ├── max_utils_test.py
│ ├── maxengine_test.py
│ ├── maxtext_utils_test.py
│ ├── model_test.py
│ ├── moe_test.py
│ ├── multi_token_prediction_test.py
│ ├── multihost_dataloading_test.py
│ ├── multimodal_utils_test.py
│ ├── pipeline_parallelism_test.py
│ ├── profiler_test.py
│ ├── pyconfig_test.py
│ ├── quantizations_test.py
│ ├── sft_data_processing_test.py
│ ├── simple_decoder_layer_test.py
│ ├── state_dtypes_test.py
│ ├── tfds_data_processing_test.py
│ ├── tokenizer_test.py
│ ├── train_compile_test.py
│ ├── train_gpu_smoke_test.py
│ ├── train_int8_smoke_test.py
│ ├── train_smoke_test.py
│ └── train_using_ragged_dot_smoke_test.py
├── tokenizer.py
├── train.py
├── train_compile.py
├── train_tokenizer.py
├── utils
│ ├── __init__.py
│ ├── gcs_utils.py
│ └── lora_utils.py
├── vertex_tensorboard.py
└── weight_inspector.py
├── PREFLIGHT.md
├── README.md
├── assets
├── tokenizer
├── tokenizer.gemma
├── tokenizer.gemma3
├── tokenizer.llama2
├── tokenizer.mistral-v1
├── tokenizer.mistral-v3
└── tokenizer_llama3.tiktoken
├── benchmarks
├── Getting_Started_Benchmarking.md
├── __init__.py
├── benchmark_db_utils.py
├── benchmark_runner.py
├── command_utils.py
├── disruption_management
│ ├── __init__.py
│ ├── disruption_handler.py
│ ├── disruption_manager.py
│ ├── disruption_utils.py
│ └── monitor.py
├── llama2_v6e-256_benchmarks.py
├── maxtext_trillium_model_configs.py
├── maxtext_v5e_model_configs.py
├── maxtext_v5p_model_configs.py
├── maxtext_xpk_runner.py
├── mmlu
│ ├── __init__.py
│ ├── mmlu_categories.py
│ └── mmlu_eval.py
├── recipes
│ ├── __init__.py
│ ├── args_helper.py
│ ├── pw_elastic_training_recipe.py
│ ├── pw_long_running_recipe.py
│ ├── pw_mcjax_benchmark_recipe.py
│ ├── pw_mcjax_checkpoint_benchmark_recipe.py
│ ├── pw_remote_python_recipe.py
│ └── pw_suspend_resume.py
├── upload_metrics_to_bq.py
├── xla_flags_library.py
└── xpk_configs.py
├── code_style.sh
├── constraints_gpu.txt
├── docker_build_dependency_image.sh
├── docker_upload_runner.sh
├── download_dataset.sh
├── end_to_end
├── gpu
│ ├── a3
│ │ ├── test_convergence_125m_params.sh
│ │ ├── test_convergence_1b_params.sh
│ │ ├── test_gemma3_logits.sh
│ │ └── test_llama2_7b.sh
│ ├── mixtral
│ │ └── test_8x7b.sh
│ ├── test_collective_matmul_llama2_7b.sh
│ ├── test_feature.py
│ └── test_fp8_gemm_llama2_7b.sh
├── test_checkpoint_compatibility.sh
├── test_checkpointing.sh
├── test_generate_param_only_checkpoint.sh
├── test_jdi.sh
├── test_mtc_phase_2_save_path.sh
├── test_multi_tier_checkpointing.sh
├── test_profiler.py
└── tpu
│ ├── deepseek
│ ├── Run_DeepSeek.md
│ ├── v2-16b
│ │ └── test_deepseek.sh
│ └── v3-671b
│ │ └── test_deepseek.sh
│ ├── eval_assert.py
│ ├── gemma
│ ├── 2b
│ │ └── test_gemma.sh
│ ├── 7b
│ │ ├── 1_test_gemma.sh
│ │ └── 2_test_gemma.sh
│ └── Run_Gemma.md
│ ├── gemma2
│ ├── 27b
│ │ ├── 1_test_gemma.sh
│ │ └── 2_test_gemma.sh
│ ├── 2b
│ │ └── test_gemma2.sh
│ └── 9b
│ │ ├── 1_test_gemma.sh
│ │ └── 2_test_gemma.sh
│ ├── gemma3
│ ├── 12b
│ │ └── test_gemma3.sh
│ ├── 27b
│ │ └── test_gemma3.sh
│ ├── 4b
│ │ └── test_gemma3.sh
│ └── Run_Gemma3.md
│ ├── llama2
│ ├── 13b
│ │ ├── 1_test_llama2_13b.sh
│ │ └── 2_test_llama2_13b.sh
│ ├── 70b
│ │ ├── 1_test_llama2_70b.sh
│ │ └── 2_test_llama2_70b.sh
│ └── 7b
│ │ └── test_llama2_7b.sh
│ ├── llama3.1
│ ├── 405b
│ │ ├── 2_test_llama3.1_405b.sh
│ │ └── 3_test_llama3.1_405b.sh
│ ├── 70b
│ │ ├── 1_test_llama3.1_70b.sh
│ │ ├── 2_test_llama3.1_70b.sh
│ │ └── 3_test_llama3.1_70b.sh
│ └── 8b
│ │ ├── 1_test_llama3.1_8b.sh
│ │ ├── 2_test_llama3.1_8b.sh
│ │ └── 3_test_llama3.1_8b.sh
│ ├── llama3.3
│ └── 70b
│ │ ├── 1_test_llama3.3_70b.sh
│ │ └── 2_test_llama3.3_70b.sh
│ ├── llama3
│ ├── 70b
│ │ ├── 1_test_llama3_70b.sh
│ │ └── 2_test_llama3_70b.sh
│ └── 8b
│ │ ├── 1_test_llama3_8b.sh
│ │ └── 2_test_llama3_8b.sh
│ ├── llama4
│ ├── 1_test_llama4.sh
│ ├── 2_test_llama4.sh
│ └── Run_Llama4.md
│ ├── llama_finetuning_test.sh
│ ├── mistral
│ └── 7b
│ │ └── test_mistral-7b.sh
│ ├── mixtral
│ ├── 8x22b
│ │ ├── 1_test_mixtral.sh
│ │ └── 2_test_mixtral.sh
│ ├── 8x7b
│ │ ├── 1_test_mixtral.sh
│ │ └── 2_test_mixtral.sh
│ └── Run_Mixtral.md
│ ├── test_checkpoint_resharding.sh
│ ├── test_convergence_1b_params.sh
│ ├── test_decode.sh
│ ├── test_decode_load_quantized_ckpt.sh
│ ├── test_decode_save_quantized_ckpt.sh
│ ├── test_determinism.sh
│ ├── test_dpo.sh
│ ├── test_gpt3.sh
│ ├── test_sft_trainer.sh
│ ├── test_tflops.sh
│ ├── test_tflops_16b_params.sh
│ ├── test_tflops_32b_params.sh
│ ├── test_tflops_64b_params.sh
│ └── test_vocab_creation.sh
├── getting_started
├── Data_Input_Perf.md
├── Data_Input_Pipeline.md
├── First_run.md
├── GCP_Workload_Observability.md
├── Knowledge_Distillation.md
├── Monitor_Goodput.md
├── Run_Llama2.md
├── Run_MaxText_via_multihost_job.md
├── Run_MaxText_via_multihost_runner.md
├── Run_MaxText_via_xpk.md
├── Sharding.md
└── Use_Vertex_AI_Tensorboard.md
├── gpu_multi_process_run.sh
├── maxtext_custom_wheels.Dockerfile
├── maxtext_db_dependencies.Dockerfile
├── maxtext_dependencies.Dockerfile
├── maxtext_gpu_dependencies.Dockerfile
├── maxtext_jax_ai_image.Dockerfile
├── maxtext_libtpu_path.Dockerfile
├── maxtext_runner.Dockerfile
├── multihost_job.py
├── multihost_runner.py
├── pedagogical_examples
├── __init__.py
├── non_spmd.py
├── shardings.py
└── shmap_collective_matmul.py
├── preflight.sh
├── pylintrc
├── pyproject.toml
├── pytest.ini
├── requirements.txt
├── requirements_with_jax_ai_image.txt
├── rto_setup.sh
├── setup.py
├── setup.sh
├── setup_gcsfuse.sh
├── setup_with_retries.sh
└── unit_test_and_lint.sh
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git
2 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Changes in this file should match the requiredReviewers list in .github/workflows/AddLabel.yml
2 | * @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei
3 |
4 | # Features
5 | MaxText/experimental/rl @A9isha @khatwanimohit @gagika @richjames0
6 | MaxText/input_pipeline @aireenmei @SurbhiJainUSC @richjames0
7 | MaxText/kernels/megablox @RissyRan @michelle-yooh @gagika @richjames0
8 | MaxText/kernels/ragged_attention.py @patemotter @vipannalla @richjames0
9 | MaxText/layers/pipeline.py @gobbleturk @richjames0
10 | MaxText/layers/moe.py @RissyRan @michelle-yooh @gagika @richjames0
11 | MaxText/layers/multi_token_prediction.py @parambole @RissyRan @gagika @richjames0
12 |
13 | # Inference
14 | MaxText/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
15 | MaxText/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
16 | MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
17 |
18 | # Dockerfiles and dependencies
19 | *.Dockerfile @bvandermoon @yangyuwei @parambole @richjames0
20 | *.txt @bvandermoon @yangyuwei @parambole @richjames0
21 |
22 | # Workflow files
23 | .github/workflows @gobbleturk @khatwanimohit @shralex @parambole @richjames0
24 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Start with a short description of what the PR does and how it changes the
4 | previous behavior.
5 |
6 | The rest of the description should include relevant details and context, for example:
7 | - why is this change being made,
8 | - the problem being solved and any relevant context,
9 | - why this is a good solution,
10 | - some information about the specific implementation,
11 | - shortcomings of the solution and possible future improvements.
12 |
13 | If the change fixes a bug or a GitHub issue, please include a link, e.g.:
14 | FIXES: b/123456
15 | FIXES: #123456
16 |
17 | *Notice 1:* Once all tests pass, the "pull ready" label will automatically be assigned.
18 | This label is used for administrative purposes. Please do not add it manually.
19 |
20 | *Notice 2:* For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
21 |
22 | # Tests
23 |
24 | Please describe how you tested this change, and include any instructions and/or
25 | commands to reproduce.
26 |
27 | # Checklist
28 |
29 | Before submitting this PR, please make sure (put an X in the square brackets):
30 | - [ ] I have performed a self-review of my code.
31 | - [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
32 | - [ ] I have run end-to-end tests and provided workload links above if applicable.
33 | - [ ] I have made or will make corresponding changes to the doc if needed.
34 |
--------------------------------------------------------------------------------
/.github/workflows/require-checklist.yml:
--------------------------------------------------------------------------------
1 | name: Require Checklist
2 | on:
3 | pull_request:
4 | types: [opened, edited, synchronize]
5 | jobs:
6 | check_pr_body:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: mheap/require-checklist-action@v2
10 | with:
11 | requireChecklist: true # If this is true and there are no checklists detected, the action will fail
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/codespell-project/codespell
3 | rev: v2.2.4
4 | hooks:
5 | - id: codespell
6 | name: Running codespell for typos
7 | entry: codespell -w --skip="*.txt,pylintrc,.*,assets/*" .
8 |
--------------------------------------------------------------------------------
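
The hook above wires codespell into pre-commit. A minimal way to exercise it locally, assuming the pre-commit package is installed from PyPI:

    pip install pre-commit       # one-time install, if not already present
    pre-commit install           # register the git hook in this clone
    pre-commit run --all-files   # run the codespell hook across the whole tree
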
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [],
3 | "python.testing.cwd": "${workspaceFolder}/MaxText",
4 | "python.testing.unittestEnabled": false,
5 | "python.testing.pytestEnabled": true
6 | }
7 |
--------------------------------------------------------------------------------
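
These settings point VS Code's test discovery at pytest with MaxText/ as the working directory. The equivalent shell invocation is roughly the following, assuming pytest is installed:

    cd MaxText
    python3 -m pytest tests/
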
/AUTHORS:
--------------------------------------------------------------------------------
1 | Google LLC
2 |
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
--------------------------------------------------------------------------------
/MaxText/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2023 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | __author__ = "Google LLC"
18 | __version__ = "2025.04.25"
19 | __description__ = (
20 | "MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and "
21 | "targeting Google Cloud TPUs and GPUs for training and **inference."
22 | )
23 |
--------------------------------------------------------------------------------
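
The version and description above are ordinary module attributes, so they can be inspected at runtime; a quick check, assuming the package is importable from the repository root:

    python3 -c "import MaxText; print(MaxText.__version__)"   # prints 2025.04.25
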
/MaxText/configs/a3/llama_2_7b/16vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 16vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | # --device-type ${DEVICE_TYPE} --num-slices 16 --priority=high \
6 | # --command "bash MaxText/configs/a3/llama_2_7b/16vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"
13 | export EXECUTABLE="train"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 | export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true
23 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
24 | --xla_gpu_enable_highest_priority_async_stream=true
25 | --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
26 | --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
27 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
28 | --xla_gpu_enable_while_loop_double_buffering=true
29 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
30 | --xla_disable_hlo_passes=rematerialization"
31 |
32 | # 16 nodes
33 | python3 -m MaxText.$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
34 | dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
35 |
--------------------------------------------------------------------------------
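
The for-loop over "$@" exports every KEY=VALUE argument as an environment variable, so the OUTPUT_PATH and RUN_NAME defaults above can be overridden at invocation time. A hypothetical example (the bucket name is illustrative):

    bash MaxText/configs/a3/llama_2_7b/16vm.sh \
      OUTPUT_PATH=gs://my-bucket/experiments RUN_NAME=llama2-16vm-trial

The 1vm, 2vm, 4vm, and 8vm variants below accept overrides the same way.
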
/MaxText/configs/a3/llama_2_7b/1vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 1vm.sh"
2 |
3 | # Example command to invoke this script via XPK
4 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
5 | # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
6 | # --device-type ${DEVICE_TYPE} --num-slices 1 \
7 | # --command "bash MaxText/configs/a3/llama_2_7b/1vm.sh"
8 |
9 | # Stop execution if any command exits with error
10 | set -e
11 |
12 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
13 | export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 | export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
23 | --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true
24 | --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
25 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
26 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
27 | --xla_gpu_enable_while_loop_double_buffering=true
28 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
29 | --xla_disable_hlo_passes=rematerialization"
30 |
31 |
32 | # 1 node, DATA_DP=1, ICI_FSDP=8
33 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
34 | dcn_data_parallelism=1 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
35 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/2vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 2vm.sh"
2 |
3 | # Example command to invoke this script via XPK
4 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
5 | # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
6 | # --device-type ${DEVICE_TYPE} --num-slices 2 \
7 | # --command "bash MaxText/configs/a3/llama_2_7b/2vm.sh"
8 |
9 | # Stop execution if any command exits with error
10 | set -e
11 |
12 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
13 | export RUN_NAME="llama-2-2vm-$(date +%Y-%m-%d-%H-%M)"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 | export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true
23 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
24 | --xla_gpu_enable_highest_priority_async_stream=true
25 | --xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728
26 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
27 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
28 | --xla_gpu_enable_while_loop_double_buffering=true
29 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
30 | --xla_disable_hlo_passes=rematerialization"
31 |
32 |
33 | # 2 nodes
34 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
35 | dcn_data_parallelism=2 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
36 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/4vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 4vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | # --device-type ${DEVICE_TYPE} --num-slices 4 \
6 | # --command "bash MaxText/configs/a3/llama_2_7b/4vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-4vm-$(date +%Y-%m-%d-%H-%M)"
13 |
14 | # Set environment variables
15 | for ARGUMENT in "$@"; do
16 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
17 | export "$KEY"="$VALUE"
18 | done
19 |
20 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
21 | --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
22 | --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true
23 | --xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728
24 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
25 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
26 | --xla_gpu_enable_while_loop_double_buffering=true
27 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
28 | --xla_disable_hlo_passes=rematerialization"
29 |
30 | # 4 nodes
31 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
32 | dcn_data_parallelism=4 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
33 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/8vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 8vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | # --device-type ${DEVICE_TYPE} --num-slices 8 \
6 | # --command "bash MaxText/configs/a3/llama_2_7b/8vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-8vm-$(date +%Y-%m-%d-%H-%M)"
13 |
14 | # Set environment variables
15 | for ARGUMENT in "$@"; do
16 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
17 | export "$KEY"="$VALUE"
18 | done
19 |
20 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
21 | --xla_gpu_enable_latency_hiding_scheduler=true
22 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
23 | --xla_gpu_enable_highest_priority_async_stream=true
24 | --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
25 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
26 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
27 | --xla_gpu_enable_while_loop_double_buffering=true
28 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
29 | --xla_disable_hlo_passes=rematerialization"
30 |
31 | # 8 nodes
32 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
33 | dcn_data_parallelism=8 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
34 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # High Performance Model Configs on A3 GPU
18 | Expected performance results for the Llama2-7B model running on A3 GPUs:
19 |
20 |
21 | ### Llama2-7B
22 | | Hardware | TFLOP/sec/chip |
23 | | ---------------------- | ---------------- |
24 | | 1x A3 (h100-80gb-8) | 492 |
25 | | 2x A3 (h100-80gb-8) | 422 |
26 | | 4x A3 (h100-80gb-8) | 407 |
27 | | 8x A3 (h100-80gb-8) | 409 |
28 | | 16x A3 (h100-80gb-8) | 375 |
29 |
--------------------------------------------------------------------------------
/MaxText/configs/dpo.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | use_dpo: true
4 | train_data_columns: ['chosen', 'rejected']
5 | eval_data_columns: ['chosen', 'rejected']
6 | base_output_directory: 'gs://maxtext-external/logs'
7 |
8 | per_device_batch_size: 2.0
9 | steps: 10
10 | max_target_length: 512
11 | eval_interval: 5 # test eval once, in the middle of 10 training steps
12 | eval_steps: 2
13 |
14 | # TFDS Pipeline ----------------------
15 | dataset_type: tfds
16 | dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
17 | dataset_name: 'tfds:1.0.0'
18 | eval_dataset_name: 'tfds:1.0.0'
19 | eval_split: 'test'
20 |
21 | # HF Pipeline -------------------------
22 | hf_eval_split: 'test'
23 |
24 | gradient_clipping_threshold: 10.0
25 | learning_rate: 5.0e-7
26 | dpo_label_smoothing: 0.0
27 | dpo_beta: 0.1
28 |
29 | enable_goodput_recording: false
30 | monitor_goodput: false
31 | enable_checkpointing: true
32 |
--------------------------------------------------------------------------------
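
This config layers DPO settings on top of base.yml but does not pick a model or tokenizer, so those still come from the command line. A minimal sketch of a launch; the model name, tokenizer path, and run name are illustrative assumptions, not values taken from the config:

    python3 -m MaxText.train MaxText/configs/dpo.yml \
      run_name=dpo-demo model_name=llama2-7b \
      tokenizer_path=assets/tokenizer.llama2
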
/MaxText/configs/experimental/1024b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 1024b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/1024b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 | remat_policy=full global_parameter_scale=1024\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
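
Note that the script reads $RUN_NAME and $PLATFORM without assigning them, so both must be supplied through the KEY=VALUE loop. A hypothetical invocation (values are illustrative):

    bash MaxText/configs/experimental/1024b.sh \
      RUN_NAME=scale-1024b-trial PLATFORM=gke

The 32b through 512b siblings in this directory follow the same pattern.
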
/MaxText/configs/experimental/128b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 128b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/128b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=30 per_device_batch_size=2 enable_checkpointing=false\
24 | remat_policy=minimal global_parameter_scale=128\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/256b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 256b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/256b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 | remat_policy=minimal global_parameter_scale=256\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/32b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 32b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/32b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=30 per_device_batch_size=8 enable_checkpointing=false\
24 | remat_policy=minimal global_parameter_scale=32\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/512b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 512b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/512b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 | remat_policy=full global_parameter_scale=512\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/64b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 64b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/64b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 | export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 | steps=30 per_device_batch_size=4 enable_checkpointing=false\
24 | remat_policy=minimal global_parameter_scale=64\
25 | ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\
26 | max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 | dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 | dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/gpu_smoke_test.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | hardware: "gpu"
4 | attention: "dot_product"
5 | base_emb_dim: 8
6 | base_num_query_heads: 4
7 | base_num_kv_heads: 4
8 | base_mlp_dim: 32
9 | base_num_decoder_layers: 8
10 | head_dim: 16
11 | per_device_batch_size: 2
12 | max_target_length: 1024
13 | dataset_type: "synthetic"
14 | steps: 10
15 |
--------------------------------------------------------------------------------
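
With tiny model dimensions, a synthetic dataset, and 10 steps, this config is a quick sanity check rather than a benchmark. A sketch of a typical launch, where the run name and output bucket are assumptions:

    python3 -m MaxText.train MaxText/configs/gpu_smoke_test.yml \
      run_name=gpu-smoke base_output_directory=gs://my-bucket/smoke
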
/MaxText/configs/inference_jetstream.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | enable_jax_profiler: False
4 | jax_profiler_port: 9999
5 |
6 | enable_model_warmup: False
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek2-16b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V2-Lite - 16B
16 |
17 | base_emb_dim: 2048
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 10944
21 | base_moe_mlp_dim: 1408
22 | base_num_decoder_layers: 27
23 | first_num_dense_layers: 1
24 | mlp_activations: ["silu","linear"]
25 | vocab_size: 102400
26 | enable_dropout: False
27 | logits_via_embedding: False
28 | normalization_layer_epsilon: 1.0e-6
29 | num_experts: 64
30 | num_experts_per_tok: 6
31 | shared_experts: 2
32 | routed_scaling_factor: 1.0
33 | routed_score_func: "softmax"
34 | routed_bias: False
35 | decoder_block: "deepseek"
36 | # MLA
37 | attention_type: "mla"
38 | q_lora_rank: 0
39 | kv_lora_rank: 512
40 | qk_nope_head_dim: 128
41 | qk_rope_head_dim: 64
42 | v_head_dim: 128
43 | # RoPE
44 | rope_type: "yarn"
45 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
46 | max_position_embeddings: 163840
47 | original_max_position_embeddings: 4096
48 | rope_factor: 40
49 | beta_fast: 32
50 | mscale: 0.707
51 |
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek2-236b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V2 - 236B
16 | # Please note: DeepSeek V2 - 236B is not fully supported at this moment
17 |
18 | base_emb_dim: 5120
19 | base_num_query_heads: 128
20 | base_num_kv_heads: 128
21 | base_mlp_dim: 12288
22 | base_moe_mlp_dim: 1536
23 | base_num_decoder_layers: 60
24 | first_num_dense_layers: 1
25 | mlp_activations: ["silu","linear"]
26 | vocab_size: 102400
27 | enable_dropout: False
28 | logits_via_embedding: False
29 | normalization_layer_epsilon: 1.0e-6
30 | num_experts: 160
31 | num_experts_per_tok: 6
32 | shared_experts: 2
33 | routed_scaling_factor: 16.0
34 | routed_score_func: "softmax"
35 | routed_bias: False
36 | decoder_block: "deepseek"
37 | # MLA
38 | attention_type: "mla"
39 | q_lora_rank: 1536
40 | kv_lora_rank: 512
41 | qk_nope_head_dim: 128
42 | qk_rope_head_dim: 64
43 | v_head_dim: 128
44 | # RoPE
45 | rope_type: "yarn"
46 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
47 | max_position_embeddings: 163840
48 | original_max_position_embeddings: 4096
49 | rope_factor: 40
50 | beta_fast: 32
51 | mscale: 0.707
52 |
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek3-671b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V3 - 671B
16 |
17 | # For DeepSeek's default device-limited routing,
18 | # set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
19 |
20 | base_emb_dim: 7168
21 | base_num_query_heads: 128
22 | base_num_kv_heads: 128
23 | base_mlp_dim: 18432
24 | base_moe_mlp_dim: 2048
25 | base_num_decoder_layers: 61
26 | first_num_dense_layers: 3
27 | mlp_activations: ["silu","linear"]
28 | vocab_size: 129280
29 | enable_dropout: False
30 | logits_via_embedding: False
31 | normalization_layer_epsilon: 1.0e-6
32 | num_experts: 256
33 | num_experts_per_tok: 8
34 | shared_experts: 1
35 | routed_scaling_factor: 2.5
36 | routed_score_func: "sigmoid"
37 | routed_bias: True
38 | decoder_block: "deepseek"
39 | # MLA
40 | attention_type: "mla"
41 | q_lora_rank: 1536
42 | kv_lora_rank: 512
43 | qk_nope_head_dim: 128
44 | qk_rope_head_dim: 64
45 | v_head_dim: 128
46 | mscale: 1.0
47 | # RoPE
48 | rope_type: "yarn"
49 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
50 | max_position_embeddings: 163840
51 | original_max_position_embeddings: 4096
52 | rope_factor: 40
53 | beta_fast: 32
54 |
--------------------------------------------------------------------------------
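
Per the note at the top of this file, device-limited routing is enabled via command-line overrides rather than in the YAML. A hypothetical launch combining this model config with those flags (the run name is illustrative):

    python3 -m MaxText.train MaxText/configs/base.yml \
      model_name=deepseek3-671b n_routing_groups=8 topk_routing_group=4 \
      run_name=ds3-train
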
/MaxText/configs/models/gemma-2b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma-2b
16 |
17 | base_emb_dim: 2048
18 | base_num_query_heads: 8
19 | base_num_kv_heads: 1
20 | base_mlp_dim: 16384
21 | base_num_decoder_layers: 18
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma-7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma-7b
16 |
17 | base_emb_dim: 3072
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 24576
21 | base_num_decoder_layers: 28
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-27b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-27B
16 |
17 | base_emb_dim: 4608
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 36864
21 | base_num_decoder_layers: 23 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 128
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-2b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-2B
16 |
17 | base_emb_dim: 2304
18 | base_num_query_heads: 8
19 | base_num_kv_heads: 4
20 | base_mlp_dim: 9216
21 | base_num_decoder_layers: 13 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-9b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-9B
16 |
17 | base_emb_dim: 3584
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 8
20 | base_mlp_dim: 14336
21 | base_num_decoder_layers: 21 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma3-12b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma3-12b
16 |
17 | base_num_decoder_layers: 48
18 | base_emb_dim: 3840
19 | base_num_query_heads: 16
20 | base_num_kv_heads: 8
21 | base_mlp_dim: 15360
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 262_144
25 | decoder_block: "gemma3"
26 | normalization_layer_epsilon: 1e-6
27 | logits_via_embedding: True
28 | sliding_window_size: 1024
29 | use_post_attn_norm: true
30 | use_post_ffw_norm: true
31 | local_rope_max_timescale: 10_000
32 | rope_max_timescale: 1_000_000
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma3-27b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma3-27b
16 |
17 | base_num_decoder_layers: 62
18 | base_emb_dim: 5376
19 | base_num_query_heads: 32
20 | base_num_kv_heads: 16
21 | base_mlp_dim: 21504
22 | head_dim: 128
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 262_144
25 | decoder_block: "gemma3"
26 | normalization_layer_epsilon: 1e-6
27 | logits_via_embedding: True
28 | sliding_window_size: 1024
29 | use_post_attn_norm: true
30 | use_post_ffw_norm: true
31 | local_rope_max_timescale: 10_000
32 | rope_max_timescale: 1_000_000
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma3-4b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma3-4b
16 |
17 | base_num_decoder_layers: 34
18 | base_emb_dim: 2560
19 | base_num_query_heads: 8
20 | base_num_kv_heads: 4
21 | base_mlp_dim: 10240
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 262_144
25 | decoder_block: "gemma3"
26 | normalization_layer_epsilon: 1e-6
27 | logits_via_embedding: True
28 | sliding_window_size: 1024
29 | use_post_attn_norm: true
30 | use_post_ffw_norm: true
31 | local_rope_max_timescale: 10_000
32 | rope_max_timescale: 1_000_000
33 |
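The gemma3 configs carry two RoPE timescales: local_rope_max_timescale for sliding-window layers and rope_max_timescale for global-attention layers. A hedged sketch of the standard RoPE inverse-frequency computation these values parameterize (function and variable names are illustrative, not the MaxText API):

import numpy as np

def rope_inv_freq(head_dim: int, max_timescale: float) -> np.ndarray:
    """Standard RoPE: one rotation frequency per pair of head dimensions."""
    exponents = np.arange(0, head_dim, 2) / head_dim
    return 1.0 / (max_timescale ** exponents)

local_freqs = rope_inv_freq(256, 10_000.0)      # sliding-window layers
global_freqs = rope_inv_freq(256, 1_000_000.0)  # global layers

The larger global timescale slows the slowest rotations, so positional phases stay distinguishable over much longer contexts.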
--------------------------------------------------------------------------------
/MaxText/configs/models/gpt3-175b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing software
10 | # distributed under the License is distributed on an "AS IS" BASIS
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gpt3-175b
16 |
17 | base_emb_dim: 12288
18 | base_num_query_heads: 96
19 | base_num_kv_heads: 96
20 | base_mlp_dim: 49152
21 | base_num_decoder_layers: 96
22 | head_dim: 128
23 | trainable_position_size: 16384
24 | mlp_activations: ["gelu"]
25 | vocab_size: 50304
26 | enable_dropout: False
27 | logits_via_embedding: True
28 | normalize_embedding_logits: False
29 | logits_dot_in_fp32: False
30 | normalization_layer_epsilon: 1.e-05
31 | use_iota_embed: True
32 | fused_qkv: True
33 | opt_type: "adam_pax"
34 | decoder_block: "gpt3"
35 | dataset_path: "gs://mlperf-llm-public2"
36 | dataset_name: "c4/en:3.0.4"
37 | eval_dataset_name: "c4/en:3.0.5"
38 | gradient_clipping_threshold: 1.
39 | adam_b1: 0.9
40 | adam_b2: 0.95
41 | adam_eps: 1.e-8
42 | adam_weight_decay: 0.1
43 | checkpoint_period: 10_000
44 | target_eval_loss: 2.69
45 | eval_per_device_batch_size: 1.0
46 |
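Three of these flags govern how the final logits are produced: logits_via_embedding ties the output projection to the token-embedding table, while normalize_embedding_logits and logits_dot_in_fp32 toggle an extra scaling and an fp32 cast. A simplified sketch under those assumptions (illustrative; the 1/sqrt scaling is the T5-style convention, an assumption rather than a reading of MaxText's exact code):

import numpy as np

def tied_logits(hidden, embedding, normalize=False, dot_in_fp32=False):
    """Vocab logits from dotting hidden states with the embedding table."""
    if normalize:                    # normalize_embedding_logits
        hidden = hidden / np.sqrt(hidden.shape[-1])
    if dot_in_fp32:                  # logits_dot_in_fp32
        hidden = hidden.astype(np.float32)
        embedding = embedding.astype(np.float32)
    return hidden @ embedding.T      # [..., vocab_size]

h = np.zeros((2, 12288), dtype=np.float16)
emb = np.zeros((50304, 12288), dtype=np.float16)
assert tied_logits(h, emb).shape == (2, 50304)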
--------------------------------------------------------------------------------
/MaxText/configs/models/gpt3-22b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing software
10 | # distributed under the License is distributed on an "AS IS" BASIS
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gpt3-22b
16 |
17 | base_emb_dim: 6144
18 | base_num_query_heads: 24
19 | base_num_kv_heads: 24
20 | base_mlp_dim: 24576
21 | base_num_decoder_layers: 48
22 | head_dim: 256
23 | max_target_length: 1024
24 | trainable_position_size: 16384
25 | mlp_activations: ["gelu"]
26 | vocab_size: 32768
27 | enable_dropout: False
28 | logits_via_embedding: True
29 | normalize_embedding_logits: False
30 | logits_dot_in_fp32: False
31 | normalization_layer_epsilon: 1.e-05
32 | use_iota_embed: True
33 | fused_qkv: True
34 | opt_type: "adam_pax"
35 | decoder_block: "gpt3"
36 | gradient_clipping_threshold: 1.
37 | adam_b1: 0.9
38 | adam_b2: 0.95
39 | adam_eps: 1.e-8
40 | adam_weight_decay: 0.1
41 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpt3-52k.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing software
10 | # distributed under the License is distributed on an "AS IS" BASIS
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gpt3-52k, i.e. a small fake model for testing purposes only
16 |
17 | base_emb_dim: 16
18 | base_num_query_heads: 2
19 | base_num_kv_heads: 2
20 | base_mlp_dim: 64
21 | base_num_decoder_layers: 1
22 | head_dim: 8
23 | trainable_position_size: 2048
24 | mlp_activations: ["gelu"]
25 | vocab_size: 1024
26 | enable_dropout: False
27 | logits_via_embedding: True
28 | normalize_embedding_logits: False
29 | logits_dot_in_fp32: False
30 | normalization_layer_epsilon: 1.e-05
31 | use_iota_embed: True
32 | fused_qkv: True
33 | opt_type: "adam_pax"
34 | decoder_block: "gpt3"
35 | gradient_clipping_threshold: 1.
36 | adam_b1: 0.9
37 | adam_b2: 0.95
38 | adam_eps: 1.e-8
39 | adam_weight_decay: 0.1
40 | attention: "dot_product" # head_dim 8 is too small for splash/flash attention
41 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpt3-6b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing software
10 | # distributed under the License is distributed on an "AS IS" BASIS
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gpt3-6b
16 |
17 | base_emb_dim: 3072
18 | base_num_query_heads: 12
19 | base_num_kv_heads: 12
20 | base_mlp_dim: 12288
21 | base_num_decoder_layers: 48
22 | head_dim: 256
23 | max_target_length: 1024
24 | trainable_position_size: 16384
25 | mlp_activations: ["gelu"]
26 | vocab_size: 32768
27 | enable_dropout: False
28 | logits_via_embedding: True
29 | normalize_embedding_logits: False
30 | logits_dot_in_fp32: False
31 | normalization_layer_epsilon: 1.e-05
32 | use_iota_embed: True
33 | fused_qkv: True
34 | opt_type: "adam_pax"
35 | decoder_block: "gpt3"
36 | gradient_clipping_threshold: 1.
37 | adam_b1: 0.9
38 | adam_b2: 0.95
39 | adam_eps: 1.e-8
40 | adam_weight_decay: 0.1
41 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/llama2_70b.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | run_name: "gpu_train_test"
4 | # Args coming from the NVIDIA spreadsheet http://shortn/_AhULYn1mX4.
5 | hardware: "gpu"
6 | steps: 30
7 | model_name: "llama2-70b"
8 | enable_checkpointing: False
9 | attention: "cudnn_flash_te"
10 | remat_policy: "full"
11 | use_iota_embed: True
12 | scan_layers: True
13 | dataset_type: "synthetic"
14 | async_checkpointing: False
15 | logits_dot_in_fp32: False
16 |
17 | per_device_batch_size: 6
18 | max_target_length: 4096
19 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/llama2_7b.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | run_name: "gpu_train_test"
4 | hardware: "gpu"
5 | steps: 30
6 | per_device_batch_size: 4
7 | max_target_length: 4096
8 | model_name: "llama2-7b"
9 | enable_checkpointing: False
10 | attention: "cudnn_flash_te"
11 | remat_policy: "minimal_flash"
12 | use_iota_embed: True
13 | scan_layers: False
14 | dataset_type: "synthetic"
15 | async_checkpointing: False
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/llama3.1_405b.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 | run_name: "gpu_train_test"
3 | # Args coming from the NVIDIA spreadsheet http://shortn/_AhULYn1mX4.
4 | hardware: "gpu"
5 | steps: 10
6 | model_name: "llama3.1-405b"
7 | enable_checkpointing: False
8 | #attention: "cudnn_flash_te"
9 | remat_policy: "full"
10 | use_iota_embed: True
11 | scan_layers: True
12 | dataset_type: "synthetic"
13 | async_checkpointing: False
14 | logits_dot_in_fp32: False
15 | per_device_batch_size: 1.0
16 | max_target_length: 4096
17 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/llama3_70b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 | # The model_name guarantees we will use the correct model from
17 | # configs/models/llama3-70b.yml, see update_model_vars in pyconfig.py for details.
18 | model_name: "llama3-70b"
19 |
20 | run_name: "gpu_train_test"
21 | hardware: "gpu"
22 | steps: 30
23 | per_device_batch_size: 4
24 | max_target_length: 8192
25 | attention: "cudnn_flash_te"
26 | remat_policy: "full"
27 | use_iota_embed: True
28 | dataset_type: "synthetic"
29 | reuse_example_batch: 1
30 | enable_checkpointing: False
31 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/llama3_8b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 | # The model_name guarantees we will use the correct model from
17 | # configs/models/llama3-8b.yml, see update_model_vars in pyconfig.py for details.
18 | model_name: "llama3-8b"
19 |
20 | run_name: "gpu_train_test"
21 | hardware: "gpu"
22 | steps: 30
23 | per_device_batch_size: 12
24 | max_target_length: 8192
25 | attention: "cudnn_flash_te"
26 | remat_policy: "minimal_flash"
27 | use_iota_embed: True
28 | dataset_type: "synthetic"
29 | reuse_example_batch: 1
30 | enable_checkpointing: False
31 |
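These GPU files layer three sources of values: base.yml (via base_config), the model file selected by model_name, and the overrides written here. A simplified resolver sketch (illustrative only; the real precedence, including CLI overrides, lives in update_model_vars in pyconfig.py):

import yaml

def load_config(path: str) -> dict:
    """Merge base_config, then the model_name file, then file-local keys."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    merged = {}
    if "base_config" in cfg:
        merged.update(load_config("MaxText/configs/" + cfg["base_config"]))
    if "model_name" in cfg:
        with open("MaxText/configs/models/%s.yml" % cfg["model_name"]) as f:
            merged.update(yaml.safe_load(f))
    merged.update(cfg)  # later sources win in this sketch
    return merged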
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/mixtral_8x1b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 | # The model_name guarantees we use the correct model params from
17 | # configs/models/mixtral-8x7b.yml. mixtral-8x1b is mixtral-8x7b but
18 | # with different base_num_decoder_layers.
19 | model_name: "mixtral-8x7b"
20 | base_num_decoder_layers: 5
21 |
22 | run_name: "gpu_train_test"
23 | hardware: "gpu"
24 | steps: 30
25 |
26 | per_device_batch_size: 8
27 | max_target_length: 4096
28 | attention: "cudnn_flash_te"
29 | remat_policy: "full"
30 | use_iota_embed: True
31 | dataset_type: "synthetic"
32 | reuse_example_batch: 1
33 | enable_checkpointing: False
34 | megablox: False
35 | sparse_matmul: False
36 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/mixtral_8x2b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 | # The model_name guarantees we use the correct model params from
17 | # configs/models/mixtral-8x7b.yml. mixtral-8x2b is mixtral-8x7b but
18 | # with different base_num_decoder_layers.
19 | model_name: "mixtral-8x7b"
20 | base_num_decoder_layers: 10
21 |
22 | run_name: "gpu_train_test"
23 | hardware: "gpu"
24 | steps: 30
25 |
26 | per_device_batch_size: 8
27 | max_target_length: 4096
28 | attention: "cudnn_flash_te"
29 | remat_policy: "full"
30 | use_iota_embed: True
31 | dataset_type: "synthetic"
32 | reuse_example_batch: 1
33 | enable_checkpointing: False
34 | megablox: False
35 | sparse_matmul: False
36 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gpu/mixtral_8x7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 | # The model_name guarantees we will use the correct model from
17 | # configs/models/mixtral-8x7b.yml, see update_model_vars in pyconfig.py for details.
18 | model_name: "mixtral-8x7b"
19 |
20 | run_name: "gpu_train_test"
21 | hardware: "gpu"
22 | steps: 30
23 | per_device_batch_size: 12
24 | max_target_length: 4096
25 | attention: "cudnn_flash_te"
26 | remat_policy: "minimal_flash"
27 | use_iota_embed: True
28 | dataset_type: "synthetic"
29 | reuse_example_batch: 1
30 | enable_checkpointing: False
31 | megablox: False
32 | scan_layers: False
33 | tokenizer_path: "/deps/assets/tokenizer.mistral-v1"
34 | profiler: "nsys"
35 | capacity_factor: 1.0
36 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama2-13b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama2-13b
16 |
17 | base_emb_dim: 5120
18 | base_num_query_heads: 40
19 | base_num_kv_heads: 40
20 | base_mlp_dim: 13824
21 | base_num_decoder_layers: 40
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 32000
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | decoder_block: "llama2"
29 | logical_axis_rules: [['norm', 'fsdp']]
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama2-70b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama2-70b
16 |
17 | base_emb_dim: 8192
18 | base_num_query_heads: 64
19 | base_num_kv_heads: 8
20 | base_mlp_dim: 28672
21 | base_num_decoder_layers: 80
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 32000
25 | logits_via_embedding: False
26 | normalization_layer_epsilon: 1.0e-5
27 | decoder_block: "llama2"
28 | logical_axis_rules: [['norm', 'fsdp']]
29 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama2-7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama2-7b
16 |
17 | base_emb_dim: 4096
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 32
20 | base_mlp_dim: 11008
21 | base_num_decoder_layers: 32
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 32000
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | decoder_block: "llama2"
29 | logical_axis_rules: [['norm', 'fsdp']]
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3-405b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3-405b
16 |
17 |
18 | base_emb_dim: 16384
19 | base_num_query_heads: 128
20 | base_num_kv_heads: 8
21 | base_num_decoder_layers: 126
22 | base_mlp_dim: 53248
23 | head_dim: 128
24 | mlp_activations: ["silu","linear"]
25 | vocab_size: 128256
26 | enable_dropout: False
27 | logits_via_embedding: False
28 | normalization_layer_epsilon: 1.0e-5
29 | rope_max_timescale: 500_000
30 | decoder_block: "llama2" # Uses the same decoder block as llama2
31 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3-70b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3-70b
16 |
17 | base_emb_dim: 8192
18 | base_num_query_heads: 64
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 80
21 | base_mlp_dim: 28672
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3-8b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3-8b
16 |
17 | base_emb_dim: 4096
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 32
21 | base_mlp_dim: 14336
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3.1-405b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3.1-405b
16 |
17 | base_emb_dim: 16384
18 | base_num_query_heads: 128
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 126
21 | base_mlp_dim: 53248
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3.1-70b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3.1-70b
16 |
17 | base_emb_dim: 8192
18 | base_num_query_heads: 64
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 80
21 | base_mlp_dim: 28672
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3.1-8b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3.1-8b
16 |
17 | base_emb_dim: 4096
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 32
21 | base_mlp_dim: 14336
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama3.3-70b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for llama3.3-70b
16 |
17 | base_emb_dim: 8192
18 | base_num_query_heads: 64
19 | base_num_kv_heads: 8
20 | base_num_decoder_layers: 80
21 | base_mlp_dim: 28672
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 128256
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 500_000
29 | decoder_block: "llama2" # Uses the same decoder block as llama2
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/llama4-17b-128e.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | decoder_block: "llama4"
17 | mlp_activations: ["silu","linear"]
18 | enable_dropout: False
19 | tokenizer_type: "huggingface"
20 |
21 | base_emb_dim: 5120
22 | base_num_decoder_layers: 48
23 | base_num_query_heads: 40
24 | base_num_kv_heads: 8
25 | base_mlp_dim: 16384
26 | base_moe_mlp_dim: 8192
27 | vocab_size: 202048
28 | normalization_layer_epsilon: 1e-05
29 | rope_max_timescale: 500000
30 | rope_type: "llama3.1"
31 | rope_use_scale: False
32 | num_experts: 128
33 | shared_experts: 1
34 | num_experts_per_tok: 1
35 | use_qk_norm: False
36 | nope_layer_interval: 4 # Every fourth layer should NOT use RoPE
37 | interleave_moe_layer_step: 2 # Every 2nd layer is an MoE layer; the 1st layer in each pair is dense
38 | inhomogeneous_layer_cycle_interval: 4 # The NoPE/MoE pattern repeats every four layers (the least common multiple of the NoPE and MoE intervals)
39 |
40 | temperature_tuning: True
41 | # Chunk attention is used on all RoPE layers;
42 | # NoPE layers use global attention instead
43 | chunk_attn_window_size: 8192
44 | image_size_for_vit: 336
45 |
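The three interval knobs define a repeating per-layer pattern. A sketch of how the NoPE and MoE flags could be derived for each layer (1-based indexing is an assumption of this illustration, not a statement about MaxText's internals):

def layer_pattern(num_layers, nope_interval=4, moe_step=2):
    """(is_nope, is_moe) per layer: every 4th skips RoPE, every 2nd is MoE."""
    flags = []
    for i in range(1, num_layers + 1):
        flags.append((i % nope_interval == 0, i % moe_step == 0))
    return flags

# The pattern repeats every lcm(4, 2) = 4 layers, matching
# inhomogeneous_layer_cycle_interval: 4.
print(layer_pattern(8))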
--------------------------------------------------------------------------------
/MaxText/configs/models/llama4-17b-16e.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | decoder_block: "llama4"
17 | mlp_activations: ["silu","linear"]
18 | enable_dropout: False
19 | logits_via_embedding: False
20 | tokenizer_type: "huggingface"
21 |
22 | base_emb_dim: 5120
23 | base_num_decoder_layers: 48
24 | base_num_query_heads: 40
25 | base_num_kv_heads: 8
26 | base_mlp_dim: 16384
27 | base_moe_mlp_dim: 8192
28 | vocab_size: 202048
29 | normalization_layer_epsilon: 1e-05
30 | rope_max_timescale: 500000
31 | rope_type: "llama3.1"
32 | num_experts: 16
33 | shared_experts: 1
34 | num_experts_per_tok: 1
35 | use_qk_norm: True # Llama4 models apply an L2Norm to the Query and Keys after RoPE
36 | nope_layer_interval: 4 # Every fourth layer should NOT use RoPE
37 | interleave_moe_layer_step: 1 # Every layer is an MoE layer
38 | inhomogeneous_layer_cycle_interval: 4 # The NoPE/MoE pattern repeats every four layers (the least common multiple of the NoPE and MoE intervals)
39 |
40 | temperature_tuning: True
41 | # Chunk attention is used on all RoPE layers;
42 | # NoPE layers use global attention instead
43 | chunk_attn_window_size: 8192
44 | image_size_for_vit: 336
45 |
--------------------------------------------------------------------------------
/MaxText/configs/models/mistral-7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for mistral-7b
16 |
17 | base_emb_dim: 4096
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 8
20 | base_mlp_dim: 14336
21 | base_num_decoder_layers: 32
22 | head_dim: 128
23 | mlp_activations: ["silu","linear"]
24 | vocab_size: 32000
25 | enable_dropout: False
26 | logits_via_embedding: False
27 | normalization_layer_epsilon: 1.0e-5
28 | rope_max_timescale: 1_000_000
29 | decoder_block: "mistral"
30 |
--------------------------------------------------------------------------------
/MaxText/configs/models/mixtral-8x22b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for mixtral-8x22b
16 | # tokenizer_path is assets/tokenizer.mistral-v3
17 |
18 | base_emb_dim: 6144
19 | base_num_query_heads: 48
20 | base_num_kv_heads: 8
21 | base_mlp_dim: 16384
22 | base_num_decoder_layers: 56
23 | head_dim: 128
24 | mlp_activations: ["silu","linear"]
25 | vocab_size: 32768
26 | enable_dropout: False
27 | logits_via_embedding: False
28 | normalization_layer_epsilon: 1.0e-5
29 | num_experts: 8
30 | num_experts_per_tok: 2
31 | rope_max_timescale: 1_000_000
32 | decoder_block: "mixtral"
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/mixtral-8x7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for mixtral-8x7b
16 | # tokenizer_path is assets/tokenizer.mistral-v1
17 |
18 | base_emb_dim: 4096
19 | base_num_query_heads: 32
20 | base_num_kv_heads: 8
21 | base_mlp_dim: 14336
22 | base_num_decoder_layers: 32
23 | head_dim: 128
24 | mlp_activations: ["silu","linear"]
25 | vocab_size: 32000
26 | enable_dropout: False
27 | logits_via_embedding: False
28 | normalization_layer_epsilon: 1.0e-5
29 | num_experts: 8
30 | num_experts_per_tok: 2
31 | rope_max_timescale: 1_000_000
32 | decoder_block: "mixtral"
33 |
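num_experts: 8 with num_experts_per_tok: 2 is the standard Mixtral top-2 routing. A minimal sketch of the routing math (illustrative numpy; MaxText's megablox/sparse_matmul paths implement this far more efficiently):

import numpy as np

def top2_route(router_logits, k=2):
    """Select top-k experts per token; softmax the selected logits as gates."""
    topk_idx = np.argsort(router_logits, axis=-1)[..., -k:]
    topk_logits = np.take_along_axis(router_logits, topk_idx, axis=-1)
    gates = np.exp(topk_logits - topk_logits.max(-1, keepdims=True))
    gates /= gates.sum(-1, keepdims=True)
    return topk_idx, gates

experts, weights = top2_route(np.random.randn(4, 8))  # [tokens, num_experts]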
--------------------------------------------------------------------------------
/MaxText/configs/quantization/README.md:
--------------------------------------------------------------------------------
1 | # Mixed precision quantization configs (currently supported for inference only).
2 |
3 | This directory contains sample json files representing mixed precision quantization configs.
4 |
5 | A mixed precision config JSON contains the following:
6 | Keys represent a regex for the layer to which the config is applied.
7 | Values represent the quantization config for the corresponding layer.
8 |
9 | The quantization config for any layer is defined using the following variables:
10 | w_bits: Number of bits used for weights, default None (implying no quantization)
11 | a_bits: Number of bits used for activations, default None (implying no quantization)
12 | w_scale: Clipping scale for weights, default 1.0
13 | a_scale: Clipping scale for activations, default 1.0
14 | tile_size: tile size for subchannel, default -1 (implying no subchannel)
15 |
16 | For example, the config below implies 4-bit weight-only quantization for layers wi_0 and wo.
17 | {
18 | ".*/wi_0": {"w_bits": 4},
19 | ".*/wo": {"w_bits": 4}
20 | }
21 |
22 | A special key '__default__' can be used to override the default values.
23 | For example, the following config defines 8-bit weight-only quantization for all layers.
24 | {
25 | "__default__": {"w_bits": 8}
26 | }
27 |
28 | ## To configure mixed precision quantization, define the following:
29 | 1. Create a JSON file (e.g. example.json) in this directory with the desired config.
30 | 2. Set the following parameters defined in base.yml:
31 | quantization="intmp"
32 | quant_cfg_path="/example.json"
33 |
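To make the regex lookup concrete, here is a hedged sketch of resolving one layer's config against such a JSON file (illustrative; resolve_layer_config is not a MaxText function):

import json, re

DEFAULTS = {"w_bits": None, "a_bits": None, "w_scale": 1.0, "a_scale": 1.0, "tile_size": -1}

def resolve_layer_config(layer_path, cfg):
    """Start from defaults, apply '__default__', then the first matching regex key."""
    out = dict(DEFAULTS)
    out.update(cfg.get("__default__", {}))
    for pattern, overrides in cfg.items():
        if pattern != "__default__" and re.fullmatch(pattern, layer_path):
            out.update(overrides)
            break
    return out

cfg = json.loads('{"__default__": {"w_bits": 8}, ".*/wi_0": {"w_bits": 4}}')
print(resolve_layer_config("decoder/layers_0/mlp/wi_0", cfg))  # w_bits becomes 4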
--------------------------------------------------------------------------------
/MaxText/configs/quantization/dense_llm_subchannel.json:
--------------------------------------------------------------------------------
1 | {
2 | "__default__": {"w_bits": 8, "a_bits": 8},
3 | ".*/query": {"w_bits": 4, "tile_size": 128},
4 | ".*/key": {"w_bits": 4, "tile_size": 256},
5 | ".*/value": {"w_bits": 4},
6 | ".*/out": {"w_bits": 4},
7 | ".*/wi_0": {"w_bits": 4},
8 | ".*/wo": {"w_bits": 4}
9 | }
10 |
--------------------------------------------------------------------------------
/MaxText/configs/quantization/dense_llm_weight_only_scale.json:
--------------------------------------------------------------------------------
1 | {
2 | "__default__": {"w_bits": 8},
3 | ".*/query": {"w_bits": 4, "w_scale": 0.8},
4 | ".*/key": {"w_bits": 4, "w_scale": 0.9},
5 | ".*/value": {"w_bits": 4},
6 | ".*/out": {"w_bits": 4},
7 | ".*/wi_0": {"w_bits": 4},
8 | ".*/wo": {"w_bits": 4}
9 | }
--------------------------------------------------------------------------------
/MaxText/configs/quantization/int4_weight_only.json:
--------------------------------------------------------------------------------
1 | {
2 | "__default__": {"w_bits": 4}
3 | }
--------------------------------------------------------------------------------
/MaxText/configs/quantization/int8_weight_only.json:
--------------------------------------------------------------------------------
1 | {
2 | "__default__": {"w_bits": 8}
3 | }
--------------------------------------------------------------------------------
/MaxText/configs/sft.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | base_config: "base.yml"
16 |
17 | use_sft: True
18 | # sft_train_on_completion_only=True trains only on completion tokens; when False, trains on both prompt and completion tokens
19 | sft_train_on_completion_only: True
20 | packing: True
21 | learning_rate: 2.e-5
22 |
23 | # -------------- HF pipeline --------------
24 | dataset_type: hf
25 | hf_path: 'HuggingFaceH4/ultrachat_200k'
26 | train_split: 'train_sft'
27 | hf_eval_split: 'test_sft'
28 | train_data_columns: ['messages']
29 | eval_data_columns: ['messages']
30 |
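sft_train_on_completion_only controls which tokens contribute to the loss. A minimal masking sketch (illustrative; the is_completion indicator is an assumption, not the MaxText pipeline's actual field name):

import numpy as np

def completion_loss_mask(is_completion, completion_only):
    """1.0 where a token counts toward the loss, 0.0 where it is ignored."""
    if completion_only:
        return is_completion.astype(np.float32)  # prompt tokens masked out
    return np.ones_like(is_completion, dtype=np.float32)

# tokens: [prompt, prompt, completion, completion]
print(completion_loss_mask(np.array([0, 0, 1, 1]), True))  # [0. 0. 1. 1.]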
--------------------------------------------------------------------------------
/MaxText/configs/tpu_smoke_test.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | hardware: "tpu"
4 | async_checkpointing: false
5 | attention: autoselected
6 | base_emb_dim: 32
7 | base_num_query_heads: 8
8 | base_num_kv_heads: 8
9 | base_mlp_dim: 32
10 | base_num_decoder_layers: 8
11 | dataset_type: "synthetic"
12 | head_dim: 128
13 | per_device_batch_size: 1
14 | steps: 10
15 |
--------------------------------------------------------------------------------
/MaxText/configs/trillium/gemma2_9b.sh:
--------------------------------------------------------------------------------
1 | # Gemma2-9b model.
2 | # This config will work out of the box for any number of trillium-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/trillium/gemma2_9b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 |
13 |
14 | # Stop execution if any command exits with error
15 | set -e
16 |
17 | export EXECUTABLE="train" # or train_compile
18 | export RUN_PREFLIGHT="true"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # The setup accommodates two cases:
27 | # 1) Passing the 'RUN_NAME' variable at runtime
28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
29 | if [ -n "$RUN_NAME" ];
30 | then
31 | export M_RUN_NAME=$RUN_NAME
32 | fi
33 |
34 | # Set up network optimizations
35 | if [ "$RUN_PREFLIGHT" = "true" ]; then
36 | bash preflight.sh
37 | fi
38 |
39 | # Train
40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=114688 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
41 |
42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=gemma2-9b\
43 | steps=15 per_device_batch_size=3 enable_checkpointing=false\
44 | remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1\
45 | max_target_length=8192 base_output_directory=$OUTPUT_PATH\
46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\
47 | attention='flash' sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048
48 |
--------------------------------------------------------------------------------
/MaxText/configs/trillium/llama2_7b_4096.sh:
--------------------------------------------------------------------------------
1 | # Llama2-7b model.
2 | # This config will work out of the box for any number of trillium-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/trillium/llama2_7b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 |
13 |
14 | # Stop execution if any command exits with error
15 | set -e
16 |
17 | export EXECUTABLE="train" # or train_compile
18 | export RUN_PREFLIGHT="true"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # The setup accommodates two cases:
27 | # 1) Passing the 'RUN_NAME' variable at runtime
28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
29 | if [ -n "$RUN_NAME" ];
30 | then
31 | export M_RUN_NAME=$RUN_NAME
32 | fi
33 |
34 | # Set up network optimizations
35 | if [ "$RUN_PREFLIGHT" = "true" ]; then
36 | bash preflight.sh
37 | fi
38 |
39 | # Train
40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
41 |
42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-7b\
43 | steps=15 per_device_batch_size=12 enable_checkpointing=false\
44 | remat_policy=full ici_fsdp_parallelism=-1\
45 | max_target_length=4096 base_output_directory=$OUTPUT_PATH\
46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\
47 | attention='flash' sa_block_q=1024 sa_block_q_dkv=2048 sa_block_q_dq=2048
48 |
--------------------------------------------------------------------------------
/MaxText/configs/trillium/mixtral_8x7b.sh:
--------------------------------------------------------------------------------
1 | # Mixtral-8x7b model.
2 | # This config will work out of the box for any number of trillium-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/trillium/mixtral_8x7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 |
13 |
14 | # Stop execution if any command exits with error
15 | set -e
16 |
17 | export EXECUTABLE="train" # or train_compile
18 | export RUN_PREFLIGHT="true"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # The setup accommodates two cases:
27 | # 1) Passing the 'RUN_NAME' variable at runtime
28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
29 | if [ -n "$RUN_NAME" ];
30 | then
31 | export M_RUN_NAME=$RUN_NAME
32 | fi
33 |
34 | # Set up network optimizations
35 | if [ "$RUN_PREFLIGHT" = "true" ]; then
36 | bash preflight.sh
37 | fi
38 |
39 | # Train
40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=81920 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
41 |
42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=mixtral-8x7b\
43 | steps=15 per_device_batch_size=32 enable_checkpointing=false\
44 | remat_policy=full ici_fsdp_parallelism=-1\
45 | max_target_length=1024 base_output_directory=$OUTPUT_PATH\
46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\
47 | attention='flash' sa_block_q=1024 sa_block_q_dkv=2048 sa_block_q_dq=2048
48 |
--------------------------------------------------------------------------------
/MaxText/configs/v4/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # High Performance Model Configs on TPU v4
18 | Expected performance results for 22B and 52B parameter models running on TPU v4:
19 |
20 |
21 | ### 22B model
22 | | Hardware | TFLOP/sec/chip | MFU |
23 | | ----------- | ---------------- | ----- |
24 | | 1x v4-128 | 156 | 56.7% |
25 | | 2x v4-128 | 152 | 55.2% |
26 | | 4x v4-128 | 149 | 54.3% |
27 | | 8x v4-128 | 146 | 53.2% |
28 |
29 | ### 52B model
30 | | Hardware | TFLOP/sec/chip | MFU |
31 | | ----------- | ---------------- | ----- |
32 | | 1x v4-384 | 154 | 56.0% |
32 | | 2x v4-384 | 162 | 58.9% |
33 |
34 | Note: the 2x v4-384 result is quirkily higher than the single-slice result because of choices made by the compiler, not for a fundamental reason.
--------------------------------------------------------------------------------
/MaxText/configs/v5e/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # High Performance Model Training Configs on TPU v5e
18 | Expected performance results for 16B, 32B, 64B, and 128B parameter models running on TPU v5e:
19 |
20 |
21 | | Hardware | 16B TFLOP/sec/chip | 16B MFU | 32B TFLOP/sec/chip | 32B MFU | 64B TFLOP/sec/chip | 64B MFU | 128B TFLOP/sec/chip | 128B MFU |
22 | | ----------- | -----------------: | ------- | -----------------: | ------- | -----------------: | ------- | ------------------: | -------- |
23 | | 1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
24 | | 2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
25 | | 4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
26 | | 8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
27 | | 16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
28 | | 32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
29 |
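MFU here is achieved TFLOP/sec/chip divided by the chip's peak. Assuming v5e's published peak of roughly 197 bf16 TFLOP/sec (an assumption of this note, not stated in the table), the two columns are consistent:

def mfu(achieved_tflops, peak_tflops=197.0):
    """Model FLOPs Utilization: achieved throughput over hardware peak."""
    return achieved_tflops / peak_tflops

print(f"{mfu(120):.2%}")  # ~60.9%, close to the 61.10% reported for 1x v5e-256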
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama2_13b.sh:
--------------------------------------------------------------------------------
1 | # Llama2 13B model.
2 | # This config will work out of the box for any number of v5e-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/v5e/llama2_13b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 | # Example to AOT compile:
13 | # bash MaxText/configs/v5e/llama2_13b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2
14 |
15 |
16 | # Stop execution if any command exits with error
17 | set -e
18 |
19 | export EXECUTABLE="train" # or train_compile
20 | export RUN_PREFLIGHT="true"
21 |
22 | # Set environment variables
23 | for ARGUMENT in "$@"; do
24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
25 | export "$KEY"="$VALUE"
26 | done
27 |
28 | # The setup accommodates two cases:
29 | # 1) Passing the 'RUN_NAME' variable at runtime
30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
31 | if [ -n "$RUN_NAME" ];
32 | then
33 | export M_RUN_NAME=$RUN_NAME
34 | fi
35 |
36 | # Set up network optimizations
37 | if [ "$RUN_PREFLIGHT" = "true" ]; then
38 | bash preflight.sh
39 | fi
40 |
41 | # Train
42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
43 |
44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-13b\
45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=8 remat_policy=qkv_proj_offloaded\
47 | steps=15 enable_checkpointing=false use_iota_embed=true
48 |
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama2_70b.sh:
--------------------------------------------------------------------------------
1 | # Llama2 70B model.
2 | # This config will work out of the box for any number of v5e-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/v5e/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 | # Example to AOT compile:
13 | # bash MaxText/configs/v5e/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2
14 |
15 |
16 | # Stop execution if any command exits with error
17 | set -e
18 |
19 | export EXECUTABLE="train" # or train_compile
20 | export RUN_PREFLIGHT="true"
21 |
22 | # Set environment variables
23 | for ARGUMENT in "$@"; do
24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
25 | export "$KEY"="$VALUE"
26 | done
27 |
28 | # The setup accommodates two cases:
29 | # 1) Passing the 'RUN_NAME' variable at runtime
30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
31 | if [ -n "$RUN_NAME" ];
32 | then
33 | export M_RUN_NAME=$RUN_NAME
34 | fi
35 |
36 | # Set up network optimizations
37 | if [ "$RUN_PREFLIGHT" = "true" ]; then
38 | bash preflight.sh
39 | fi
40 |
41 | # Train
42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
43 |
44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-70b\
45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded\
47 | steps=15 enable_checkpointing=false use_iota_embed=true
48 |
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama2_70b_v5e-16.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # tensor = 8, autoregressive=2
4 | # per_device_batch_size=6
5 | # weight bf16, kv cache bf16
6 |
7 | model_name: "llama2-70b"
8 | sharding_strategy: "experimental"
9 | attention: 'dot_product'
10 | allow_split_physical_axes: True
11 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
12 | replicate_quant_scale: True
13 |
14 | logical_axis_rules: [
15 | ['embed', []],
16 | ['vocab', ['tensor', 'autoregressive']],
17 | ['activation_batch', []],
18 | ['activation_length', []],
19 | ['activation_embed', []],
20 | ['activation_vocab', ['tensor']],
21 | ['heads', ['tensor', 'autoregressive']],
22 | ['kv', []],
23 | # TODO: fix the incorrect XLA ops generated for the following sharding.
24 | # ['q_heads', ['tensor', 'autoregressive']],
25 | # ['kv_head_dim', ['autoregressive']],
26 | ['q_heads', ['tensor']],
27 | ['kv_heads', ['tensor']],
28 | ['kv_head_dim', []],
29 | ['activation_prefill_kv_batch', []],
30 | ['activation_kv_batch', ['autoregressive']],
31 | ['activation_kv_heads', ['tensor']],
32 | ['activation_kv_head_dim', []],
33 | ['activation_heads', ['tensor']],
34 | ['activation_kv', ['tensor', 'autoregressive']],
35 | ['norm', []],
36 | ['mlp', ['tensor', 'autoregressive']],
37 | ['activation_mlp', ['tensor', 'autoregressive']],
38 | ['cache_batch_prefill', []],
39 | ['cache_batch', ['autoregressive']],
40 | ['cache_sequence', []],
41 | ['cache_heads', ['tensor']],
42 | ['cache_kv', []],
43 | ]
44 |
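Each rule above maps a logical array dimension to the physical mesh axes it is sharded over; an empty list means the dimension is replicated. Per the header comment, the v5e-16 mesh here is `tensor`=8 by `autoregressive`=2 (8 × 2 = 16 chips). A minimal sketch of how such rules resolve to a sharding spec, using Flax's logical-partitioning helper on a hand-picked subset of the rules:

```
# Sketch: resolving logical axis names to mesh axes with Flax's helper.
# The rules are a small subset of this file, rewritten as Python tuples.
from flax.linen import logical_to_mesh_axes

rules = (
    ("embed", None),                           # replicated
    ("heads", ("tensor", "autoregressive")),   # split across both mesh axes
)
print(logical_to_mesh_axes(("embed", "heads"), rules=rules))
# PartitionSpec(None, ('tensor', 'autoregressive'))
```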
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama2_7b.sh:
--------------------------------------------------------------------------------
1 | # Llama2 7B model.
2 | # This config will work out of the box for any number of v5e-256 slices.
3 | #
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | #
9 | # Example to invoke this script:
10 | # bash MaxText/configs/v5e/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://"
11 | #
12 | # Example to AOT compile:
13 | # bash MaxText/configs/v5e/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2
14 |
15 |
16 | # Stop execution if any command exits with error
17 | set -e
18 |
19 | export EXECUTABLE="train" # or train_compile
20 | export RUN_PREFLIGHT="true"
21 |
22 | # Set environment variables
23 | for ARGUMENT in "$@"; do
24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
25 | export "$KEY"="$VALUE"
26 | done
27 |
28 | # The setup accommodates two cases:
29 | # 1) Passing the 'RUN_NAME' variable at runtime
30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
31 | if [ -n "$RUN_NAME" ];
32 | then
33 | export M_RUN_NAME=$RUN_NAME
34 | fi
35 |
36 | # Set up network optimizations
37 | if [ "$RUN_PREFLIGHT" = "true" ]; then
38 | bash preflight.sh
39 | fi
40 |
41 | # Train
42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
43 |
44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-7b\
45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\
46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=4 remat_policy=save_qkv_proj\
47 | steps=15 enable_checkpointing=false use_iota_embed=true
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama3_405b_v5e-64.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # v5e-64
4 | # tensor = 8, autoregressive=8
5 | # per_device_batch_size=1
6 | # weight bf16, kv cache bf16
7 |
8 | model_name: "llama3.1-405b"
9 | sharding_strategy: "experimental"
10 | attention: 'dot_product'
11 | allow_split_physical_axes: True
12 | tokenizer_path: "assets/tokenizer_llama3.tiktoken"
13 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
14 | replicate_quant_scale: True
15 |
16 | logical_axis_rules: [
17 | ['embed', []],
18 | ['vocab', ['tensor', 'autoregressive']],
19 | ['activation_batch', []],
20 | ['activation_length', []],
21 | ['activation_embed', []],
22 | ['activation_vocab', ['tensor', 'autoregressive']],
23 | ['heads', ['tensor', 'autoregressive']],
24 | ['kv', []],
25 | ['kv_heads', ['tensor']],
26 | ['q_heads', ['tensor']],
27 | ['kv_head_dim', []],
28 | ['activation_prefill_kv_batch', []],
29 | ['activation_kv_batch', ['autoregressive']],
30 | ['activation_kv_heads', ['tensor']],
31 | ['activation_kv_head_dim', []],
32 | ['activation_heads', ['tensor']],
33 | ['activation_kv', ['tensor', 'autoregressive']],
34 | ['norm', []],
35 | ['mlp', ['tensor', 'autoregressive']],
36 | ['activation_mlp', ['tensor', 'autoregressive']],
37 | ['cache_batch_prefill', []],
38 | ['cache_batch', ['autoregressive']],
39 | ['cache_sequence', []],
40 | ['cache_heads', ['tensor']],
41 | ['cache_kv', []],
42 | ]
43 |
--------------------------------------------------------------------------------
/MaxText/configs/v5e/llama3_70b_v5e-16.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # tensor = 8, autoregressive=2
4 | # per_device_batch_size=6
5 | # weight bf16, kv cache bf16
6 |
7 | model_name: "llama3-70b"
8 | tokenizer_path: "assets/tokenizer_llama3.tiktoken"
9 | sharding_strategy: "experimental"
10 | attention: 'dot_product'
11 | allow_split_physical_axes: True
12 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
13 | replicate_quant_scale: True
14 |
15 | logical_axis_rules: [
16 | ['embed', []],
17 | ['vocab', ['tensor', 'autoregressive']],
18 | ['activation_batch', []],
19 | ['activation_length', []],
20 | ['activation_embed', []],
21 | ['activation_vocab', ['tensor']],
22 | ['heads', ['tensor', 'autoregressive']],
23 | ['kv', []],
24 | # TODO: fix the incorrect XLA ops generated for the following sharding.
25 | # ['q_heads', ['tensor', 'autoregressive']],
26 | # ['kv_head_dim', ['autoregressive']],
27 | ['q_heads', ['tensor']],
28 | ['kv_heads', ['tensor']],
29 | ['kv_head_dim', []],
30 | ['activation_prefill_kv_batch', []],
31 | ['activation_kv_batch', ['autoregressive']],
32 | ['activation_kv_heads', ['tensor']],
33 | ['activation_kv_head_dim', []],
34 | ['activation_heads', ['tensor']],
35 | ['activation_kv', ['tensor', 'autoregressive']],
36 | ['norm', []],
37 | ['mlp', ['tensor', 'autoregressive']],
38 | ['activation_mlp', ['tensor', 'autoregressive']],
39 | ['cache_batch_prefill', []],
40 | ['cache_batch', ['autoregressive']],
41 | ['cache_sequence', []],
42 | ['cache_heads', ['tensor']],
43 | ['cache_kv', []],
44 | ]
45 |
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-1024 slice.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_1024.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
16 |
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_12288.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-12288 slice, with a custom topology of 8x16x48.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_12288.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-2048 slice.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_2048.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 8 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-3072 slice.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_3072.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 12 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_4096.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-4096 slice, with a custom topology of 4x8x64.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_4096.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 4 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
--------------------------------------------------------------------------------
/MaxText/configs/v5p/gpt3_175b/v5p_8192.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # GPT-3 175B Model.
3 | # Train GPT-3 175B on v5p-8192 slice, with a custom topology of 8x16x32.
4 |
5 | # Example to invoke this script:
6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_8192.sh YOUR_RUN gs://YOUR_BUCKET
7 |
8 | set -euox pipefail
9 |
10 | # Read arguments or use defaults from environment variables
11 | RUNNAME=${1:-${RUNNAME:-some-run}}
12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}}
13 |
14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh
15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 32 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}"
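Across these launchers, the three numbers that follow the remat policy appear to be parallelism dimensions whose product equals the slice's chip count (a v5p-N slice has N/2 chips; the custom topologies noted above, e.g. 8x16x32 for v5p-8192, contain the same factors). A quick sanity check of that reading — an inference from the invocations, not documented behavior of `gpt3_175b_base.sh`:

```
# Sanity check: for each launcher, (dim1 * dim2 * dim3) should equal the
# slice's chip count (a v5p-N slice has N/2 chips). The dims are copied from
# the script invocations above; their interpretation is an assumption.
launchers = {
    "v5p_1024":  (1024,  (1, 64, 8)),
    "v5p_2048":  (2048,  (8, 16, 8)),
    "v5p_3072":  (3072,  (12, 16, 8)),
    "v5p_4096":  (4096,  (4, 64, 8)),
    "v5p_8192":  (8192,  (16, 32, 8)),
    "v5p_12288": (12288, (16, 48, 8)),
}
for name, (n, dims) in launchers.items():
    chips = n // 2
    assert dims[0] * dims[1] * dims[2] == chips, name
    print(f"{name}: {chips} chips = {dims[0]} x {dims[1]} x {dims[2]}")
```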
--------------------------------------------------------------------------------
/MaxText/experimental/rl/grpo.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | use_grpo: True
4 | train_data_columns: 'prompt'
5 |
6 | learning_rate: 1.e-6
7 |
8 | dataset_type: hf # currently only the Hugging Face input pipeline is supported with GRPO.
9 |
10 | #TRL
11 | max_prefill_predict_length: 512
12 | max_target_length: 1024
13 |
14 | adam_b2: 0.999
15 |
16 | # Group Relative Policy Optimization (GRPO)
17 | num_generations: 4
18 | grpo_beta: 0.04
19 |
20 | decode_sampling_strategy: "weighted"
21 | decode_sampling_temperature: 0.9
22 | async_checkpointing: false
23 |
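Here `num_generations` is the number of completions sampled per prompt and `grpo_beta` weights the KL penalty against the reference policy, as in standard GRPO. The step that gives GRPO its name is normalizing each reward against its own generation group; a minimal sketch of that computation (textbook GRPO math, not necessarily line-for-line what `grpo_trainer.py` does):

```
# Sketch of GRPO's group-relative advantage: each prompt's num_generations
# completions are scored, and every reward is normalized against its own
# group's mean and std. Illustrative, not the repo's implementation.
import jax.numpy as jnp

def group_relative_advantages(rewards: jnp.ndarray, num_generations: int, eps: float = 1e-6):
    groups = rewards.reshape(-1, num_generations)   # [num_prompts, num_generations]
    mean = groups.mean(axis=1, keepdims=True)
    std = groups.std(axis=1, keepdims=True)
    return ((groups - mean) / (std + eps)).reshape(-1)

rewards = jnp.array([1.0, 0.0, 0.5, 0.5])           # one prompt, num_generations=4
print(group_relative_advantages(rewards, num_generations=4))
```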
--------------------------------------------------------------------------------
/MaxText/experimental/rl/grpo_trainer_test.yml:
--------------------------------------------------------------------------------
1 | base_config: "grpo.yml"
2 |
3 | model_name: "default"
4 | vocab_size: 128256
5 | max_target_length: 32
6 | per_device_batch_size: 1
7 | max_prefill_predict_length: 16
8 | dataset_type: "synthetic"
9 | dtype: "float32"
10 | matmul_precision: "high"
11 | logits_dot_in_fp32: True
12 | prompt: "Hello world this is a test"
13 | init_weights_seed: 42
14 | enable_dropout: False
15 |
--------------------------------------------------------------------------------
/MaxText/globals.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import os.path
18 |
19 | PKG_DIR = os.path.dirname(os.path.abspath(__file__))
20 |
21 | __all__ = ["PKG_DIR"]
22 |
--------------------------------------------------------------------------------
/MaxText/inference/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | model_name: "llama3.1-405b"
4 | sharding_strategy: "experimental"
5 | attention: 'dot_product'
6 | allow_split_physical_axes: True
7 | tokenizer_path: "assets/tokenizer_llama3.tiktoken"
8 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
9 | replicate_quant_scale: True
10 |
11 | inference_server: "ExperimentalMaxtextDisaggregatedServer"
12 |
13 | logical_axis_rules: [
14 | ['embed', []],
15 | ['vocab', ['tensor', 'autoregressive']],
16 | ['activation_batch', []],
17 | ['activation_length', []],
18 | ['activation_embed', []],
19 | ['activation_vocab', ['tensor', 'autoregressive']],
20 | ['heads', ['tensor', 'autoregressive']],
21 | ['kv', []],
22 | ['kv_heads', ['tensor']],
23 | ['q_heads', ['tensor']],
24 | ['kv_head_dim', []],
25 | ['activation_prefill_kv_batch', []],
26 | ['activation_kv_batch', ['autoregressive']],
27 | ['activation_kv_heads', ['tensor']],
28 | ['activation_kv_head_dim', []],
29 | ['activation_heads', ['tensor']],
30 | ['activation_kv', ['tensor', 'autoregressive']],
31 | ['norm', []],
32 | ['mlp', ['tensor', 'autoregressive']],
33 | ['activation_mlp', ['tensor', 'autoregressive']],
34 | ['cache_batch_prefill', []],
35 | ['cache_batch', []],
36 | ['cache_sequence', []],
37 | ['cache_heads', ['tensor']],
38 | ['cache_kv', []],
39 | ]
40 |
--------------------------------------------------------------------------------
/MaxText/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # tensor = 8, autoregressive=2
4 | # per_device_batch_size=6
5 | # weight bf16, kv cache bf16
6 |
7 | model_name: "llama2-70b"
8 | sharding_strategy: "experimental"
9 | attention: 'dot_product'
10 | allow_split_physical_axes: True
11 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
12 | replicate_quant_scale: True
13 |
14 | logical_axis_rules: [
15 | ['embed', []],
16 | ['vocab', ['tensor', 'autoregressive']],
17 | ['activation_batch', []],
18 | ['activation_length', []],
19 | ['activation_embed', []],
20 | ['activation_vocab', ['tensor']],
21 | ['heads', ['tensor', 'autoregressive']],
22 | ['kv', []],
23 | # TODO: fix the incorrect XLA ops generated for the following sharding.
24 | # ['q_heads', ['tensor', 'autoregressive']],
25 | # ['kv_head_dim', ['autoregressive']],
26 | ['q_heads', ['tensor']],
27 | ['kv_heads', ['tensor']],
28 | ['kv_head_dim', []],
29 | ['activation_prefill_kv_batch', []],
30 | ['activation_kv_batch', ['autoregressive']],
31 | ['activation_kv_heads', ['tensor']],
32 | ['activation_kv_head_dim', []],
33 | ['activation_heads', ['tensor']],
34 | ['activation_kv', ['tensor', 'autoregressive']],
35 | ['norm', []],
36 | ['mlp', ['tensor', 'autoregressive']],
37 | ['activation_mlp', ['tensor', 'autoregressive']],
38 | ['cache_batch_prefill', []],
39 | ['cache_batch', ['autoregressive']],
40 | ['cache_sequence', []],
41 | ['cache_heads', ['tensor']],
42 | ['cache_kv', []],
43 | ]
44 |
--------------------------------------------------------------------------------
/MaxText/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # v5e-64
4 | # tensor = 8, autoregressive=8
5 | # per_device_batch_size=1
6 | # weight bf16, kv cache bf16
7 |
8 | model_name: "llama3.1-405b"
9 | sharding_strategy: "experimental"
10 | attention: 'dot_product'
11 | allow_split_physical_axes: True
12 | tokenizer_path: "assets/tokenizer_llama3.tiktoken"
13 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
14 | replicate_quant_scale: True
15 |
16 | logical_axis_rules: [
17 | ['embed', []],
18 | ['vocab', ['tensor', 'autoregressive']],
19 | ['activation_batch', []],
20 | ['activation_length', []],
21 | ['activation_embed', []],
22 | ['activation_vocab', ['tensor', 'autoregressive']],
23 | ['heads', ['tensor', 'autoregressive']],
24 | ['kv', []],
25 | ['kv_heads', ['tensor']],
26 | ['q_heads', ['tensor']],
27 | ['kv_head_dim', []],
28 | ['activation_prefill_kv_batch', []],
29 | ['activation_kv_batch', ['autoregressive']],
30 | ['activation_kv_heads', ['tensor']],
31 | ['activation_kv_head_dim', []],
32 | ['activation_heads', ['tensor']],
33 | ['activation_kv', ['tensor', 'autoregressive']],
34 | ['norm', []],
35 | ['mlp', ['tensor', 'autoregressive']],
36 | ['activation_mlp', ['tensor', 'autoregressive']],
37 | ['cache_batch_prefill', []],
38 | ['cache_batch', ['autoregressive']],
39 | ['cache_sequence', []],
40 | ['cache_heads', ['tensor']],
41 | ['cache_kv', []],
42 | ]
43 |
--------------------------------------------------------------------------------
/MaxText/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml:
--------------------------------------------------------------------------------
1 | base_config: "inference_jetstream.yml"
2 |
3 | # tensor = 8, autoregressive=2
4 | # per_device_batch_size=6
5 | # weight bf16, kv cache bf16
6 |
7 | model_name: "llama3-70b"
8 | tokenizer_path: "assets/tokenizer_llama3.tiktoken"
9 | sharding_strategy: "experimental"
10 | attention: 'dot_product'
11 | allow_split_physical_axes: True
12 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion.
13 | replicate_quant_scale: True
14 |
15 | logical_axis_rules: [
16 | ['embed', []],
17 | ['vocab', ['tensor', 'autoregressive']],
18 | ['activation_batch', []],
19 | ['activation_length', []],
20 | ['activation_embed', []],
21 | ['activation_vocab', ['tensor']],
22 | ['heads', ['tensor', 'autoregressive']],
23 | ['kv', []],
24 | # TODO: fix the incorrect XLA ops generated for the following sharding.
25 | # ['q_heads', ['tensor', 'autoregressive']],
26 | # ['kv_head_dim', ['autoregressive']],
27 | ['q_heads', ['tensor']],
28 | ['kv_heads', ['tensor']],
29 | ['kv_head_dim', []],
30 | ['activation_prefill_kv_batch', []],
31 | ['activation_kv_batch', ['autoregressive']],
32 | ['activation_kv_heads', ['tensor']],
33 | ['activation_kv_head_dim', []],
34 | ['activation_heads', ['tensor']],
35 | ['activation_kv', ['tensor', 'autoregressive']],
36 | ['norm', []],
37 | ['mlp', ['tensor', 'autoregressive']],
38 | ['activation_mlp', ['tensor', 'autoregressive']],
39 | ['cache_batch_prefill', []],
40 | ['cache_batch', ['autoregressive']],
41 | ['cache_sequence', []],
42 | ['cache_heads', ['tensor']],
43 | ['cache_kv', []],
44 | ]
45 |
--------------------------------------------------------------------------------
/MaxText/inference/gpu/README.md:
--------------------------------------------------------------------------------
1 | ## Benchmarking Scripts
2 |
3 | This directory contains scripts used for baseline benchmarking with fixed parameters.
--------------------------------------------------------------------------------
/MaxText/inference/jetstream_pathways/README.md:
--------------------------------------------------------------------------------
1 | ## Build and upload MaxText + JetStream + Pathways Server image
2 |
3 | These instructions build the MaxText + JetStream + Pathways Server image, whose entrypoint script launches the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the MaxText framework.
4 |
5 | ```
6 | docker build -t jetstream-pathways .
7 | docker tag jetstream-pathways us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pathways:latest
8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pathways:latest
9 | ```
10 |
11 | To change the version of MaxText or JetStream that the image is built from, update the `MAXTEXT_VERSION` / `JETSTREAM_VERSION` environment variables in the Dockerfile:
12 | ```
13 | ENV MAXTEXT_VERSION=
14 | ENV JETSTREAM_VERSION=
15 | ```
--------------------------------------------------------------------------------
/MaxText/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2024 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | cd /maxtext
18 | python3 -m MaxText.maxengine_server "$@"
19 |
--------------------------------------------------------------------------------
/MaxText/inference/maxengine_server/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Ubuntu:22.04
16 | # Use Ubuntu 22.04 from Docker Hub.
17 | # https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
18 | FROM ubuntu:22.04
19 |
20 | ENV DEBIAN_FRONTEND=noninteractive
21 | ENV MAXTEXT_VERSION=main
22 | ENV JETSTREAM_VERSION=main
23 |
24 | RUN apt -y update && apt install -y --no-install-recommends \
25 | ca-certificates \
26 | git \
27 | python3.10 \
28 | python3-pip
29 |
30 | RUN update-alternatives --install \
31 | /usr/bin/python3 python3 /usr/bin/python3.10 1
32 |
33 | RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \
34 | git clone https://github.com/AI-Hypercomputer/JetStream.git
35 |
36 | RUN cd maxtext/ && \
37 | git checkout ${MAXTEXT_VERSION} && \
38 | bash setup.sh
39 |
40 | RUN cd /JetStream && \
41 | git checkout ${JETSTREAM_VERSION} && \
42 | python3 -m pip install -e .
43 |
44 | COPY maxengine_server_entrypoint.sh /usr/bin/
45 |
46 | RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh
47 |
48 | ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"]
--------------------------------------------------------------------------------
/MaxText/inference/maxengine_server/README.md:
--------------------------------------------------------------------------------
1 | ## Build and upload Maxengine Server image
2 |
3 | These instructions build the Maxengine Server image, whose entrypoint script launches the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the MaxText framework.
4 |
5 | ```
6 | docker build -t maxengine-server .
7 | docker tag maxengine-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest
8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest
9 | ```
10 |
11 | To change the version of MaxText or JetStream that the image is built from, update the `MAXTEXT_VERSION` / `JETSTREAM_VERSION` environment variables in the Dockerfile:
12 | ```
13 | ENV MAXTEXT_VERSION=
14 | ENV JETSTREAM_VERSION=
15 | ```
--------------------------------------------------------------------------------
/MaxText/inference/maxengine_server/maxengine_server_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2024 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | cd /maxtext
18 | python3 -m MaxText.maxengine_server \
19 | MaxText/configs/base.yml "$@"
20 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/matmul/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/matmul/matmul_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """matrix multiplication data types"""
15 |
16 |
17 | import jax
18 |
19 | from MaxText.inference_mlperf.matmul import timing_util
20 |
21 | _PROFILE = False
22 | MATMUL_SIZES = [(250, 2048)]
23 |
24 | _INT4 = jax.numpy.int4
25 | _INT8 = jax.numpy.int8
26 | _DEFAULT = jax.numpy.bfloat16
27 |
28 |
29 | def f(X, Y):
30 | return jax.lax.batch_matmul(X, Y)
31 |
32 |
33 | f_jit = jax.jit(f)
34 |
35 | num_matmuls, matrix_size = MATMUL_SIZES[0]
36 |
37 | for dtypeA, dtypeB in [
38 | (_INT4, _INT4),
39 | (_INT4, _INT8),
40 | (_INT8, _INT4),
41 | (_INT8, _INT8),
42 | (_INT8, _DEFAULT),
43 | (_DEFAULT, _DEFAULT),
44 | ]:
45 | A = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=dtypeA)
46 | B = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=dtypeB)
47 |
48 | print(f"A, B shape is {f(A, B).shape}. A dtype is {A.dtype}, B dtype is {B.dtype} and prod type is {f(A, B).dtype}")
49 | timing_util.simple_timeit(f_jit, A, B, task="matmul_" + str(matrix_size), enable_profile=_PROFILE)
50 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/matmul/timing_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ Timing utility functions """
15 |
16 | import datetime
17 | import os.path
18 | import tempfile
19 |
20 | import jax
21 |
22 |
23 | def simple_timeit(f, *args, tries=10, task=None, enable_profile=False):
24 | """Simple utility to time a function for multiple runs"""
25 | assert task is not None
26 |
27 | trace_name = f"{task}" # + '_' + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
28 | temp_dir = tempfile.gettempdir()
29 | trace_dir = os.path.join(temp_dir, trace_name)
30 | print(trace_dir)
31 |
32 | outcomes_ms = []
33 | jax.block_until_ready(f(*args)) # warm it up!
34 | if enable_profile:
35 | jax.profiler.start_trace(trace_dir)
36 | for _ in range(tries):
37 | s = datetime.datetime.now()
38 | jax.block_until_ready(f(*args))
39 | e = datetime.datetime.now()
40 | outcomes_ms.append(1000 * (e - s).total_seconds())
41 | if enable_profile:
42 | jax.profiler.stop_trace()
43 | average_time_ms = sum(outcomes_ms) / len(outcomes_ms)
44 | print(f"Average time ms for mm for {task} is {round(average_time_ms, 3)}")
45 | return average_time_ms / 1000
46 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.31.0
2 | nltk==3.8.1
3 | evaluate==0.4.0
4 | absl-py==1.4.0
5 | rouge-score==0.1.2
6 | sentencepiece==0.1.99
7 | accelerate==0.21.0
8 | omegaconf
9 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/trillium/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/user.conf:
--------------------------------------------------------------------------------
1 | # The format of this config file is 'key = value'.
2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t.
3 | # The model may be '*' as a wildcard; in that case the value applies to all models.
4 | # All times are in milliseconds.
5 |
6 | # Set performance_sample_count for each model.
7 | llama2-70b.*.performance_sample_count_override = 24576
8 | mixtral-8x7b.*.performance_sample_count_override = 15000
9 | *.Offline.min_duration = 600000
10 |
11 |
12 | # In the Offline scenario there is always a single query, but LoadGen maps it
13 | # to min_sample_count internally. If the dataset size is larger than 24576,
14 | # we cap min_query_count at 24576; otherwise we use the dataset size as
15 | # the limit.
16 | llama2-70b.Offline.min_query_count = 24576
17 | mixtral-8x7b.Offline.min_query_count = 15000
18 |
19 | # These fields should be defined and overridden by user.conf.
20 | *.Offline.target_qps = 5.0
21 |
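The `model.scenario.key = value` convention above, with `*` as a model wildcard, is LoadGen's standard config format. A minimal sketch of how such lines could be resolved, for illustration only (LoadGen has its own parser):

```
# Sketch: resolving 'model.scenario.key = value' lines, with '*' as a model
# wildcard and later lines overriding earlier ones. Illustrative only.
def lookup(lines, model, scenario, key):
    result = None
    for line in lines:
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        lhs, value = (part.strip() for part in line.split("=", 1))
        m, s, k = lhs.split(".", 2)
        if m in (model, "*") and s == scenario and k == key:
            result = value
    return result

conf = ["*.Offline.min_duration = 600000", "llama2-70b.Offline.min_query_count = 24576"]
print(lookup(conf, "llama2-70b", "Offline", "min_query_count"))  # 24576
```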
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/user100.conf:
--------------------------------------------------------------------------------
1 | # The format of this config file is 'key = value'.
2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t.
3 | # The model may be '*' as a wildcard; in that case the value applies to all models.
4 | # All times are in milliseconds.
5 |
6 | # Set performance_sample_count for each model.
7 | #llama2-70b.*.performance_sample_count_override = 24576
8 | llama2-70b.*.performance_sample_count_override = 100
9 | mixtral-8x7b.*.performance_sample_count_override = 100
10 | #*.Offline.min_duration = 600000
11 | *.Offline.min_duration = 60
12 |
13 |
14 | # In the Offline scenario there is always a single query, but LoadGen maps it
15 | # to min_sample_count internally. If the dataset size is larger than 24576,
16 | # we cap min_query_count at 24576; otherwise we use the dataset size as
17 | # the limit.
18 | #llama2-70b.Offline.min_query_count = 24576
19 | llama2-70b.Offline.min_query_count = 100
20 | mixtral-8x7b.Offline.min_query_count = 100
21 |
22 |
23 | # These fields should be defined and overridden by user.conf.
24 | *.Offline.target_qps = 5.0
25 |
--------------------------------------------------------------------------------
/MaxText/inference_mlperf/user5000.conf:
--------------------------------------------------------------------------------
1 | # The format of this config file is 'key = value'.
2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t.
3 | # The model may be '*' as a wildcard; in that case the value applies to all models.
4 | # All times are in milliseconds.
5 |
6 | # Set performance_sample_count for each model.
7 | #llama2-70b.*.performance_sample_count_override = 24576
8 | llama2-70b.*.performance_sample_count_override = 5000
9 | mixtral-8x7b.*.performance_sample_count_override = 5000
10 | #*.Offline.min_duration = 600000
11 | *.Offline.min_duration = 600
12 |
13 |
14 | # In the Offline scenario there is always a single query, but LoadGen maps it
15 | # to min_sample_count internally. If the dataset size is larger than 24576,
16 | # we cap min_query_count at 24576; otherwise we use the dataset size as
17 | # the limit.
18 | #llama2-70b.Offline.min_query_count = 24576
19 | llama2-70b.Offline.min_query_count = 5000
20 | mixtral-8x7b.Offline.min_query_count = 5000
21 |
22 |
23 | # These fields should be defined and overridden by user.conf.
24 | *.Offline.target_qps = 5.0
25 |
--------------------------------------------------------------------------------
/MaxText/input_pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/kernels/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/kernels/megablox/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Megablox kernel"""
15 |
16 | from MaxText.kernels.megablox.ops import gmm
17 |
--------------------------------------------------------------------------------
/MaxText/kernels/megablox/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Common utilities for GMM kernels."""
16 |
17 | import re
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | def is_tpu() -> bool:
24 | return "TPU" in jax.devices()[0].device_kind
25 |
26 |
27 | def tpu_kind() -> str:
28 | """Query identification string for the currently attached TPU."""
29 | return jax.devices()[0].device_kind
30 |
31 |
32 | _TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")
33 |
34 |
35 | def tpu_generation() -> int:
36 | """Generation number of the currently attached TPU."""
37 | if version := _TPU_KIND_PATTERN.match(tpu_kind()):
38 | return int(version[1])
39 | raise NotImplementedError("only TPU devices are supported")
40 |
41 |
42 | def supports_bfloat16_matmul() -> bool:
43 | """Does the currently attached CPU support bfloat16 inputs?"""
44 | return not is_tpu() or tpu_generation() >= 4
45 |
46 |
47 | def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
48 | if dtype not in (jnp.bfloat16, jnp.float32):
49 | raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")
50 |
51 |
52 | def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
53 | """A type to which both input should be adapted to before dot product."""
54 | # bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed
55 | # input precision, we need to convert bf16 argument to fp32 beforehand.
56 | if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16:
57 | return jnp.bfloat16
58 | else:
59 | return jnp.float32
60 |
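A short usage sketch for `select_input_dtype`: with mixed bf16/f32 operands (or on a pre-v4 TPU) it falls back to float32 before the dot, per the comment above. Shapes and values here are illustrative:

```
# Sketch: casting both operands to the dtype select_input_dtype picks
# before a matmul. Shapes and values are illustrative.
import jax.numpy as jnp

from MaxText.kernels.megablox.common import select_input_dtype

lhs = jnp.ones((8, 16), dtype=jnp.bfloat16)
rhs = jnp.ones((16, 4), dtype=jnp.float32)

dtype = select_input_dtype(lhs, rhs)   # float32 here, since the inputs are mixed
out = jnp.dot(lhs.astype(dtype), rhs.astype(dtype))
print(out.dtype)
```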
--------------------------------------------------------------------------------
/MaxText/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/layers/initializers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Initializers."""
16 |
17 | from typing import Callable, Tuple, Union
18 |
19 | import jax
20 |
21 | from flax import linen as nn
22 |
23 | from MaxText.common_types import Array, DType, Shape, PRNGKey
24 |
25 | Initializer = Callable[[PRNGKey, Shape, DType], Array]
26 | InitializerAxis = Union[int, Tuple[int, ...]]
27 | NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
28 |
29 | default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
30 |
31 | default_bias_init = jax.nn.initializers.constant(0.0)
32 |
33 |
34 | def nd_dense_init(scale, mode, distribution):
35 | """Initializer with in_axis, out_axis set at call time."""
36 |
37 | def init_fn(key, shape, dtype, in_axis, out_axis):
38 | fn = jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis, out_axis)
39 | return fn(key, shape, dtype)
40 |
41 | return init_fn
42 |
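A short usage sketch for `nd_dense_init`: because the fan axes are supplied at call time, one factory can initialize kernels whose contracting and output dimensions differ. The shapes below are illustrative:

```
# Sketch: nd_dense_init returns an initializer whose in/out axes are chosen
# at call time. The kernel shape below is illustrative.
import jax
import jax.numpy as jnp

from MaxText.layers.initializers import nd_dense_init

init_fn = nd_dense_init(1.0, "fan_in", "truncated_normal")
key = jax.random.PRNGKey(0)
# An (embed=128, heads=8, head_dim=32) kernel: fan-in on axis 0, fan-out on axes 1-2.
kernel = init_fn(key, (128, 8, 32), jnp.float32, 0, (1, 2))
print(kernel.shape)  # (128, 8, 32)
```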
--------------------------------------------------------------------------------
/MaxText/layers/normalizations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Normalization Layers."""
16 |
17 | from typing import Any, Tuple, Optional
18 |
19 | from flax import linen as nn
20 | from jax import lax
21 | import jax
22 | import jax.numpy as jnp
23 | from MaxText import max_logging
24 | from MaxText.layers.initializers import Initializer
25 |
26 |
27 | class RMSNorm(nn.Module):
28 | """RMS normalization."""
29 |
30 | epsilon: float = 1e-6
31 | dtype: Any = jnp.float32
32 | weight_dtype: Any = jnp.float32
33 | kernel_axes: Tuple[Optional[str], ...] = ()
34 | scale_init: Initializer = nn.initializers.ones
35 | parameter_memory_host_offload: bool = False
36 |
37 | @nn.compact
38 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
39 | """Applies layer normalization on the input."""
40 | x = jnp.asarray(x, jnp.float32)
41 | features = x.shape[-1]
42 | mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
43 | y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
44 | scale = self.param(
45 | "scale",
46 | nn.with_logical_partitioning(self.scale_init, self.kernel_axes),
47 | (features,),
48 | self.weight_dtype,
49 | )
50 | # Move scale to device if parameter offloading is enabled
51 | if self.parameter_memory_host_offload:
52 | max_logging.log("normalizations.py: Moving scale parameter to device")
53 | scale = jax.device_put(scale, jax._src.sharding_impls.TransferToMemoryKind("device"))
54 |
55 | scale = jnp.asarray(scale, self.dtype)
56 | return y * scale
57 |
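Unlike LayerNorm, RMSNorm skips mean-centering: it rescales by the root mean square alone, `y = x * rsqrt(mean(x²) + ε) * scale`, and the module above upcasts to float32 before the reduction. A plain-function sketch of the same computation, with the Flax parameter plumbing (scale init, partitioning, host offload) omitted:

```
# Sketch: the computation RMSNorm performs, as a plain function.
# Parameter handling (scale init, partitioning, host offload) is omitted.
import jax.numpy as jnp
from jax import lax

def rms_norm(x, scale, epsilon=1e-6, dtype=jnp.float32):
    x = jnp.asarray(x, jnp.float32)                       # upcast for the reduction
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
    return y * jnp.asarray(scale, dtype)

x = jnp.ones((2, 4))
print(rms_norm(x, scale=jnp.ones((4,))))
```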
--------------------------------------------------------------------------------
/MaxText/load_and_quantize_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """CLI utility for loading and quantizing a checkpoint."""
16 |
17 | import os
18 | from typing import Sequence
19 |
20 | from absl import app
21 |
22 | import jax
23 |
24 | from MaxText import max_utils
25 | from MaxText import maxengine
26 | from MaxText import pyconfig
27 |
28 |
29 | def main(argv: Sequence[str]) -> None:
30 | jax.config.update("jax_default_prng_impl", "unsafe_rbg")
31 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
32 |
33 | config = pyconfig.initialize(argv)
34 | validate_config(config)
35 | max_utils.print_system_information()
36 |
37 | engine = maxengine.MaxEngine(config)
38 | rng = jax.random.PRNGKey(1234)
39 | rng, rng_load_params = jax.random.split(rng)
40 |
41 | # load_params will load a checkpoint and quantize if the following parameters are set:
42 | # quantization=$valid_quantization_type \
43 | # save_quantized_params_path=$gsbucket_path \
44 | # checkpoint_is_quantized=false (default)
45 | engine.load_params(rng_load_params)
46 |
47 |
48 | def validate_config(config):
49 | assert config.load_full_state_path == "", "Operation on full states not supported! Convert to parameter checkpoint first."
50 |
51 |
52 | if __name__ == "__main__":
53 | app.run(main)
54 |
--------------------------------------------------------------------------------
/MaxText/max_logging.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2023 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Stub for logging utilities. Right now just meant to avoid raw prints."""
18 |
19 |
20 | def log(user_str):
21 | print(user_str, flush=True)
22 |
--------------------------------------------------------------------------------
/MaxText/scratch_code/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/scratch_code/gemma_7b.sh:
--------------------------------------------------------------------------------
1 | export M_LOAD_PARAMETERS_PATH=gs://runner-maxtext-logs/reroll5/checkpoints/10/items/
2 | export M_PER_DEVICE_BATCH_SIZE=24
3 | export M_MAX_PREFILL_PREDICT_LENGTH=1024
4 | export M_MAX_TARGET_LENGTH=2048
5 |
6 | #python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
7 |
8 | python3 -m MaxText.maxengine_server MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false
9 |
--------------------------------------------------------------------------------
/MaxText/scratch_code/run_inference_microbenchmark.sh:
--------------------------------------------------------------------------------
1 | # llama2-7b
2 | python3 -m MaxText.inference_microbenchmark \
3 | MaxText/configs/base.yml \
4 | async_checkpointing=false \
5 | attention=autoselected \
6 | dataset_path=gs://maxtext-dataset \
7 | ici_fsdp_parallelism=1 \
8 | ici_autoregressive_parallelism=-1 \
9 | max_prefill_predict_length=1024 \
10 | max_target_length=2048 \
11 | per_device_batch_size=16 \
12 | quantization=int8 \
13 | quantize_kvcache=True \
14 | steps=10 \
15 | scan_layers=false \
16 | model_name=llama2-7b \
17 | weight_dtype=bfloat16 \
18 | tokenizer_path=assets/tokenizer.llama2
19 |
--------------------------------------------------------------------------------
/MaxText/scratch_code/setup_transformer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
3 | python3 -m pip install tokenizers -U
4 | python3 -m pip install transformers -U
5 |
--------------------------------------------------------------------------------
/MaxText/test_assets/.gitignore:
--------------------------------------------------------------------------------
1 | golden_data_gemma-2b.jsonl
2 | golden_data_gemma2-27b.jsonl
3 | golden_data_gemma2-2b.jsonl
4 | golden_data_gemma2-9b.jsonl
5 | golden_data_gemma3-12b.jsonl
6 | golden_data_gemma3-27b.jsonl
7 | golden_data_gemma3-4b.jsonl
8 | golden_data_llama2-70b.jsonl
9 | golden_data_llama2-7b.jsonl
10 | golden_data_llama3-70b.jsonl
11 | golden_data_llama3-8b.jsonl
12 | golden_data_llama3.1-405b.jsonl
13 | golden_data_llama3.1-70b.jsonl
14 | golden_data_llama3.1-8b.jsonl
15 | golden_data_llama3.3-70b.jsonl
16 | golden_data_llama4-17b-16e.jsonl
17 | golden_data_mistral-7b.jsonl
18 | golden_data_mixtral-8x22b.jsonl
19 | golden_data_mixtral-8x7b.jsonl
20 |
--------------------------------------------------------------------------------
/MaxText/test_assets/golden_data_grpo_default.jsonl:
--------------------------------------------------------------------------------
1 | {"maxtext_loss": 0.0, "input_ids": [128000, 9906, 1917, 420, 374, 264, 1296, 128000, 9906, 1917, 420, 374, 264, 1296, 0, 0, 0, 0], "generated_completions": [128000, 9906, 1917, 420, 374, 264, 1296, 62387, 64248, 94859, 15603, 112205, 13091, 32909, 90304, 116037, 114513, 35686, 26560, 91645, 85220, 105433, 75171, 0, 0, 0, 0], "maxtext_per_token_logps_no_ckpt_loading": [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -12.265657424926758, -12.379983901977539, -13.414685249328613, -15.10253620147705, -13.143253326416016, -11.190017700195312, -12.384369850158691, -0.0, -0.0, -0.0, -0.0], "avg_kl": 0.0, "avg_advantage": 0.0}
2 |
--------------------------------------------------------------------------------
/MaxText/test_assets/golden_data_sft_default.jsonl:
--------------------------------------------------------------------------------
1 | {"data": {"messages": [{"role": "user", "content": "Hello, what is your name?"}, {"role": "assistant", "content": "I am a chatbot. How can I help?"}]}, "tokens": [1, 518, 25580, 29962, 15043, 29892, 825, 338, 596, 1024, 29973, 518, 29914, 25580, 29962, 306, 626, 263, 13563, 7451, 29889, 1128, 508, 306, 1371, 29973, 29871, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "token_log_probs": [-10.455509185791016, -11.176178932189941, -11.196348190307617, -10.699331283569336, -10.40202808380127, -10.493494987487793, -10.981508255004883, -9.65949821472168, -10.184316635131836, -11.546117782592773, -11.033979415893555, -11.696186065673828, -10.98974609375, -10.65627670288086, -9.982662200927734, -11.240318298339844, -12.635238647460938, -9.757575988769531, -12.000450134277344, -11.398622512817383, -10.542476654052734, -10.546899795532227, -11.729068756103516, -10.480279922485352, -11.757697105407715, -10.342456817626953, -9.775711059570312]}
2 |
--------------------------------------------------------------------------------
/MaxText/test_assets/test_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/MaxText/test_assets/test_image.jpg
--------------------------------------------------------------------------------
/MaxText/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/tests/aot_hlo_identical_script.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # https://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 |
13 | # Bash script to run AOT and real runs (on a v4-8)
14 | # This needs to run via a bash script so the AOT and real runs each
15 | # initialize jax/XLA separately (e.g. separate dump directories)
16 | # and we do not contaminate the XLA flags of the second run with the first.
17 |
18 | compile_dump_dir=$1
19 | train_dump_dir=$2
20 | custom_args=$3
21 |
22 | shared_args="configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=compile_equivalent_test dataset_path=gs://maxtext-dataset dataset_type=synthetic steps=5 enable_checkpointing=False $custom_args"
23 | aot_args="compile_topology=v4-8 compile_topology_num_slices=1"
24 |
25 | export XLA_FLAGS=--xla_dump_to=${compile_dump_dir}
26 | python3 -m MaxText.train_compile $shared_args $aot_args
27 |
28 | export XLA_FLAGS=--xla_dump_to=${train_dump_dir}
29 | python3 -m MaxText.train $shared_args
30 |
31 |
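32 | # Illustrative invocation (all three arguments are placeholders):
33 | #   bash aot_hlo_identical_script.sh /tmp/aot_dump /tmp/train_dump "per_device_batch_size=1"
34 | # The two XLA dump directories can then be compared (e.g. by the calling test)
35 | # to verify the AOT-compiled HLO matches the HLO of the real run.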
--------------------------------------------------------------------------------
/MaxText/tests/hf_checkpoint_conversion_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """ Tests for kernels """
18 |
19 | import numpy as np
20 | from MaxText.max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope
21 | import unittest
22 |
23 |
24 | class HFCheckpointConversionTest(unittest.TestCase):
25 |
26 | def test_huggingface_to_maxtext_back_to_huggingface_flow(self):
27 | base_num_query_heads = 16
28 | head_dim = 32
29 | wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(
30 | base_num_query_heads * head_dim, base_num_query_heads * head_dim
31 | )
32 | wq1 = wq.transpose()
33 | wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])
34 |
35 | wq3 = permute_to_match_maxtext_rope(wq2)
36 | stack_shape = (1,)
37 | x = np.zeros(stack_shape + wq3.shape, dtype=np.float16)
38 | x[0, ...] = wq3
39 | x = np.transpose(x, axes=(1, 0, 2, 3))
40 |
41 | x = x[:, 0, :, :]
42 | wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1")
43 | wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
44 | wq6 = wq5.transpose()
45 |
46 |     # Assert equality so the test actually fails on a mismatch instead of
47 |     # only printing a message.
48 |     np.testing.assert_array_equal(wq, wq6, err_msg="wq does not match wq6")
49 |
50 |     np.testing.assert_array_equal(wq1, wq5, err_msg="wq1 does not match wq5")
51 |
52 |     np.testing.assert_array_equal(wq2, wq4, err_msg="wq2 does not match wq4")
53 |
54 |
55 |
56 | if __name__ == "__main__":
57 | unittest.main()
58 |
--------------------------------------------------------------------------------
/MaxText/tests/inference/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/tests/inference/test_llama2_7b_bf16.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the arguments in an array
4 | args=(
5 | "-m"
6 | "MaxText.decode"
7 | "MaxText/configs/base.yml"
8 | "tokenizer_path=assets/tokenizer.llama2"
9 | "model_name=llama2-7b"
10 | "load_parameters_path=gs://runner-maxtext-logs/direct_generate_param_only_checkpoint_2024-06-11-04-13/checkpoints/0/items/"
11 | "checkpoint_is_quantized=false"
12 | "weight_dtype=bfloat16"
13 | "max_prefill_predict_length=16"
14 | "max_target_length=32"
15 | "ici_fsdp_parallelism=1"
16 | "ici_autoregressive_parallelism=1"
17 | "ici_tensor_parallelism=-1"
18 | "scan_layers=false"
19 | "per_device_batch_size=1"
20 | "attention=paged"
21 | "pagedattn_num_pages=64"
22 | "pagedattn_tokens_per_page=8"
23 | "pagedattn_pages_per_compute_block=4"
24 | )
25 |
26 | # Execute the Python script with the arguments
27 | python3 "${args[@]}"
28 |
29 |
--------------------------------------------------------------------------------
/MaxText/tests/inference/test_llama2_7b_int8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the arguments in an array
4 | args=(
5 | "-m"
6 | "MaxText.decode"
7 | "MaxText/configs/base.yml"
8 | "tokenizer_path=assets/tokenizer.llama2"
9 | "model_name=llama2-7b"
10 | "load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_"
11 | "checkpoint_is_quantized=true"
12 | "quantization=int8"
13 | "weight_dtype=bfloat16"
14 | "max_prefill_predict_length=16"
15 | "max_target_length=32"
16 | "ici_fsdp_parallelism=1"
17 | "ici_autoregressive_parallelism=1"
18 | "ici_tensor_parallelism=-1"
19 | "scan_layers=false"
20 | "per_device_batch_size=1"
21 | "attention=paged"
22 | "pagedattn_num_pages=64"
23 | "pagedattn_tokens_per_page=8"
24 | "pagedattn_pages_per_compute_block=4"
25 | )
26 |
27 | # Execute the Python script with the arguments
28 | python3 "${args[@]}"
29 |
--------------------------------------------------------------------------------
/MaxText/tests/integration_tests/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/MaxText/tests/integration_tests/inference_microbenchmark_smoke_test.py:
--------------------------------------------------------------------------------
1 | """Copyright 2024 Google LLC
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | """
15 |
16 | """ Smoke test for inference microbenchmark"""
17 | import jax
18 | import os.path
19 | import pytest
20 | import unittest
21 | from absl.testing import absltest
22 |
23 | from MaxText import pyconfig
24 | from MaxText.globals import PKG_DIR
25 | from MaxText.inference_microbenchmark import run_benchmarks
26 |
27 |
28 | class Inference_Microbenchmark(unittest.TestCase):
29 | """integration test for inference microbenchmark"""
30 |
31 | @pytest.mark.integration_test
32 | @pytest.mark.tpu_only
33 | def test(self):
34 | jax.config.update("jax_default_prng_impl", "unsafe_rbg")
35 | config = pyconfig.initialize(
36 | [
37 | None,
38 | os.path.join(PKG_DIR, "configs", "tpu_smoke_test.yml"),
39 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}",
40 | "ici_autoregressive_parallelism=-1",
41 | "ici_fsdp_parallelism=1",
42 | "max_prefill_predict_length=1024",
43 | "max_target_length=2048",
44 | "scan_layers=false",
45 | "weight_dtype=bfloat16",
46 | "attention=dot_product",
47 | "skip_jax_distributed_system=True",
48 | ]
49 | )
50 | run_benchmarks(config)
51 |
52 |
53 | if __name__ == "__main__":
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/MaxText/tests/integration_tests/shmap_collective_matmul_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2024 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Integration test for pedagogical_examples/shmap_collective_matmul.py"""
18 |
19 | import os.path
20 | import sys
21 |
22 | import pytest
23 |
24 | from MaxText.globals import PKG_DIR
25 |
26 | sys.path.append(os.path.join(os.path.dirname(PKG_DIR), "pedagogical_examples"))
27 |
28 | # Uncomment the import when b/415022795 is fixed
29 | # from pedagogical_examples.shmap_collective_matmul import main
30 |
31 |
32 | @pytest.mark.skip(reason="Enable when b/415022795 is fixed")
33 | @pytest.mark.integration_test
34 | @pytest.mark.tpu_only
35 | def test_shmap_collective_matmul_example():
36 | """Validate Pedagogical Example, Shmap_collective_matmul."""
37 | # Uncomment main() assertion when b/415022795 is fixed
38 | # assert main() is True
39 |
--------------------------------------------------------------------------------
/MaxText/tests/train_gpu_smoke_test.py:
--------------------------------------------------------------------------------
1 | """Copyright 2024 Google LLC
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | """
15 |
16 | """ Smoke test """
17 | import os
18 | import unittest
19 |
20 | from absl.testing import absltest
21 |
22 | from MaxText.train import main as train_main
23 | from MaxText.globals import PKG_DIR
24 |
25 |
26 | class Train(unittest.TestCase):
27 | """Smoke test for GPUs."""
28 |
29 | def test_tiny_config(self):
30 | test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
31 | train_main(
32 | [
33 | None,
34 | os.path.join(PKG_DIR, "configs", "gpu_smoke_test.yml"),
35 | # pylint: disable=f-string-without-interpolation
36 | f"base_output_directory=gs://runner-maxtext-logs",
37 | "run_name=runner_test",
38 | r"dataset_path=gs://maxtext-dataset",
39 | "enable_checkpointing=False",
40 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}",
41 | "enable_goodput_recording=False",
42 | "enable_checkpoint_cloud_logger=False",
43 | "monitor_goodput=False",
44 | ]
45 | )
46 |
47 |
48 | if __name__ == "__main__":
49 | absltest.main()
50 |
--------------------------------------------------------------------------------
/MaxText/tests/train_smoke_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2023 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """ Smoke test """
18 | import os
19 | import unittest
20 |
21 | from absl.testing import absltest
22 |
23 | from MaxText.train import main as train_main
24 | from MaxText.globals import PKG_DIR
25 |
26 |
27 | class Train(unittest.TestCase):
28 | """Smoke test G3 only"""
29 |
30 | def test_tiny_config(self):
31 | test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
32 | train_main(
33 | [
34 | None,
35 | os.path.join(PKG_DIR, "configs", "base.yml"),
36 | # pylint: disable=f-string-without-interpolation
37 | f"base_output_directory=gs://runner-maxtext-logs",
38 | "run_name=runner_test",
39 | r"dataset_path=gs://maxtext-dataset",
40 | "base_emb_dim=8",
41 | "base_num_query_heads=4",
42 | "base_num_kv_heads=4",
43 | "base_mlp_dim=32",
44 | "base_num_decoder_layers=8",
45 | "head_dim=128",
46 | "per_device_batch_size=2",
47 | "max_target_length=1024",
48 | "dataset_type=synthetic",
49 | "steps=10",
50 | "enable_checkpointing=False",
51 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}",
52 | "enable_goodput_recording=False",
53 | "enable_checkpoint_cloud_logger=False",
54 | "monitor_goodput=False",
55 | ]
56 | )
57 |
58 |
59 | if __name__ == "__main__":
60 | absltest.main()
61 |
--------------------------------------------------------------------------------
/MaxText/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/assets/tokenizer:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer
--------------------------------------------------------------------------------
/assets/tokenizer.gemma:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.gemma
--------------------------------------------------------------------------------
/assets/tokenizer.gemma3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.gemma3
--------------------------------------------------------------------------------
/assets/tokenizer.llama2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.llama2
--------------------------------------------------------------------------------
/assets/tokenizer.mistral-v1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.mistral-v1
--------------------------------------------------------------------------------
/assets/tokenizer.mistral-v3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.mistral-v3
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/benchmarks/disruption_management/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | https://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """
13 |
--------------------------------------------------------------------------------
/benchmarks/mmlu/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/benchmarks/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | """Copyright 2025 Google LLC
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | """
15 |
--------------------------------------------------------------------------------
/benchmarks/xpk_configs.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2024 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import dataclasses
18 |
19 |
20 | # This is needed to prevent circular imports.
21 | @dataclasses.dataclass
22 | class XpkClusterConfig:
23 | """Holds details related to a XPK cluster to run workloads on."""
24 |
25 | cluster_name: str
26 | project: str
27 | zone: str
28 | device_type: str
29 |
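30 |
31 | # Illustrative usage (all field values below are placeholders):
32 | #   cluster = XpkClusterConfig(
33 | #       cluster_name="my-cluster",
34 | #       project="my-gcp-project",
35 | #       zone="us-central2-b",
36 | #       device_type="v4-8",
37 | #   )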
--------------------------------------------------------------------------------
/code_style.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Clean up Python code using Pylint & Pyink
16 | # Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance
17 |
18 | set -e # Exit immediately if any command fails
19 |
20 | FOLDERS_TO_FORMAT=("MaxText" "pedagogical_examples")
21 | LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2)
22 |
23 | # Check for --check flag
24 | CHECK_ONLY_PYINK_FLAGS=""
25 | if [[ "$1" == "--check" ]]; then
26 | CHECK_ONLY_PYINK_FLAGS="--check --diff --color"
27 | fi
28 |
29 | for folder in "${FOLDERS_TO_FORMAT[@]}"
30 | do
31 | pyink "$folder" ${CHECK_ONLY_PYINK_FLAGS} --pyink-indentation=2 --line-length=${LINE_LENGTH}
32 | done
33 |
34 | for folder in "${FOLDERS_TO_FORMAT[@]}"
35 | do
36 | # pylint doesn't change files, only reports errors.
37 | pylint --disable R0401,R0917,W0201,W0613 "./$folder"
38 | done
39 |
40 | echo "Successfully clean up all codes."
41 |
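42 | # Usage:
43 | #   bash code_style.sh          # reformat in place, then lint
44 | #   bash code_style.sh --check  # report formatting diffs without modifying files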
--------------------------------------------------------------------------------
/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2023 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # This script downloads c4/en/3.0.1 to your gcs bucket directory
18 | # Usage: bash download_dataset.sh <gcp-project> <gcs-bucket-path>
19 | # Usage example: bash download_dataset.sh cloud-tpu-multipod-dev gs://maxtext-dataset
20 |
21 | function remove_trailing_slash {
22 | if [[ $1 =~ /$ ]]; then # Check if the path ends with a slash
23 | echo "${1::-1}" # Remove the last character
24 | else
25 | echo "$1" # Output the path as-is
26 | fi
27 | }
28 |
29 | gsutil -u $1 -m cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' $(remove_trailing_slash $2)/c4/en/3.0.1
30 |
--------------------------------------------------------------------------------
/end_to_end/test_jdi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Running the jax.distributed.initialize test";
4 | python3 -c "import jax; jax.distributed.initialize()";
5 | if [[ "$?" -eq "0" ]]; then
6 | echo "Test exit status 0, success!"
7 | else
8 | echo "Non-zero exit status, test failed!"
9 | fi
10 |
--------------------------------------------------------------------------------
/end_to_end/test_mtc_phase_2_save_path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H)
5 | OUTPUT_PATH=${2}
6 | DATASET_PATH=${3}
7 |
8 | export TPU_PREMAPPED_BUFFER_SIZE=20000014336
9 | export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336
10 |
11 | # Train and save checkpoint
12 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
13 | steps=100 checkpoint_period=200 run_name=$RUN_NAME enable_emergency_checkpoint=true local_checkpoint_directory=/local local_checkpoint_period=20 use_replicator_service=True replicator_backup_interval_minutes=5 metrics_file='saved_metrics.txt'
14 |
--------------------------------------------------------------------------------
/end_to_end/test_multi_tier_checkpointing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H)
5 | OUTPUT_PATH=${2}
6 | DATASET_PATH=${3}
7 |
8 | export TPU_PREMAPPED_BUFFER_SIZE=20000014336
9 | export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336
10 |
11 | # Train and save checkpoint
12 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
13 | steps=100 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='saved_metrics.txt'
14 |
15 | # Retrieve checkpoint
16 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
17 | steps=110 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='restored_metrics.txt'
18 |
19 |
20 | python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss
21 |
22 | # Clean up ramdisk
23 | rm -rf /local/*
24 |
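25 | # Illustrative invocation (bucket paths are placeholders):
26 | #   bash end_to_end/test_multi_tier_checkpointing.sh mtc_test gs://<output-bucket> gs://<dataset-bucket>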
--------------------------------------------------------------------------------
/end_to_end/tpu/gemma/Run_Gemma.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # Gemma
18 | [Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.
19 |
20 | Follow the instructions on [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) to download the Gemma model weights. You will have to consent to the Gemma license and authenticate using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
21 |
22 | After downloading the weights, run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/convert_gemma_chkpt.py), which converts the checkpoint to a MaxText-compatible format and uploads it to a GCS bucket (an illustrative invocation appears at the end of this page). You can run decode and finetuning using the instructions in the test scripts at [end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma).
23 |
24 | ## MaxText supports pretraining and finetuning with high performance
25 |
26 | Model Flop utilization for training on v5e and v5p TPUs.
27 |
28 | | Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) |
29 | | -------- | -------------- | -------------- | -------------- | -------------- |
30 | | Gemma-2b | 58% | 55% | 64% | 68% |
31 | | Gemma-7b | 58% | 60% | 70% | 70% |
32 |
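33 | As an illustrative sketch (the flag names and paths here are assumptions; the authoritative invocation lives in the [end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma) test scripts), a conversion run looks roughly like:
34 |
35 | ```sh
36 | # Hypothetical example: convert downloaded Gemma-2b weights and upload the
37 | # MaxText-compatible checkpoint to your own GCS bucket.
38 | python3 -m MaxText.convert_gemma_chkpt \
39 |   --base_model_path /path/to/kaggle/gemma-2b \
40 |   --maxtext_model_path gs://<your-bucket>/gemma-2b/maxtext \
41 |   --model_size 2b
42 | ```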
--------------------------------------------------------------------------------
/end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This file tests the quantization of the Llama3.1-405b checkpoint, and assumes an unscanned checkpoint already exists.
4 |
5 | set -ex
6 |
7 | # We install torch CPU because the checkpoint conversion script does not need a TPU/GPU
8 | python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
9 |
10 | # This is defined in 2_test_llama3.1_405b.sh
11 | export MODEL_VARIATION='llama3.1-405b'
12 | export UNSCANNED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_bf16/unscanned/0/items
13 |
14 | # Non-Googlers please remember to point `SAVE_QUANT_PARAMS_PATH` to the GCS bucket where you want to save your quantized checkpoint
15 | export SAVE_QUANT_PARAMS_PATH=gs://maxtext-llama/llama3.1_405b_int8
16 |
17 | export QUANTIZE_TYPE="int8"
18 |
19 | JAX_PLATFORMS=cpu python3 -m MaxText.load_and_quantize_checkpoint \
20 | MaxText/configs/base.yml \
21 | tokenizer_path=assets/tokenizer_llama3.tiktoken \
22 | tokenizer_type=tiktoken \
23 | load_parameters_path=${UNSCANNED_CHECKPOINT} \
24 | max_prefill_predict_length=1024 \
25 | max_target_length=2048 \
26 | model_name=${MODEL_VARIATION} \
27 | ici_fsdp_parallelism=1 \
28 | ici_autoregressive_parallelism=1 \
29 | ici_tensor_parallelism=-1 \
30 | scan_layers=false \
31 | weight_dtype=bfloat16 \
32 | per_device_batch_size=1 \
33 | attention=dot_product \
34 | quantization=${QUANTIZE_TYPE} \
35 | save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} \
36 | async_checkpointing=false
37 |
--------------------------------------------------------------------------------
/end_to_end/tpu/llama_finetuning_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script is designed for internal use within Google. External users can adapt it by:
4 | # - Updating GCS paths (gs://) to your accessible locations.
5 | # - Using the checkpoint generated from train.py or an available open-source one (https://llama.meta.com/llama-downloads/).
6 |
7 | set -e
8 | idx=$(date +%Y-%m-%d-%H-%M)
9 |
10 | base_ckpt_path=gs://maxtext-llama/test/2024-01-15-06-49/decode-ckpt-maxtext/0/items
11 | BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
12 | DATASET_PATH=gs://maxtext-dataset
13 |
14 | export LOSS_THRESHOLD=2.5
15 |
16 | python3 -m MaxText.train MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt'
17 |
18 | # Assert training loss is smaller than input LOSS_THRESHOLD
19 | python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD
--------------------------------------------------------------------------------
/end_to_end/tpu/mixtral/Run_Mixtral.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # Mixtral
18 |
19 | [Mixtral](https://mistral.ai/news/mixtral-of-experts/) is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.
20 |
21 |
22 | To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../MaxText/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You can find a Mixtral 8x7B example in the [end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts.
23 |
24 |
25 | Additionally, Mixtral integrates with [MegaBlocks](https://arxiv.org/abs/2211.15841), an efficient dropless MoE strategy, which can be activated by setting both the sparse_matmul and megablox flags to True (the default); see the illustrative command at the end of this page.
26 |
27 |
28 | ## MaxText supports pretraining and finetuning with high performance
29 |
30 | Model Flop utilization for training on v5p TPUs.
31 |
32 | | Model size | Accelerator type | TFLOP/chip/sec | Model flops utilization (MFU) |
33 | | ------------ | -------------- | -------------- | -------------- |
34 | | Mixtral 8X7B | v5p-128 | 251.94 | 54.89% |
35 |
36 |
37 |
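38 | As an illustrative sketch (paths are placeholders; the authoritative commands are in the test scripts linked above), a MegaBlocks-enabled pretraining run looks roughly like:
39 |
40 | ```sh
41 | # Hypothetical example: short synthetic-data pretraining run with the
42 | # dropless MoE path enabled (sparse_matmul and megablox both True).
43 | python3 -m MaxText.train MaxText/configs/base.yml \
44 |   model_name=mixtral-8x7b \
45 |   tokenizer_path=assets/tokenizer.mistral-v1 \
46 |   base_output_directory=gs://<your-bucket>/logs \
47 |   dataset_type=synthetic steps=10 per_device_batch_size=1 \
48 |   sparse_matmul=True megablox=True
49 | ```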
--------------------------------------------------------------------------------
/end_to_end/tpu/test_checkpoint_resharding.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H)
5 | OUTPUT_PATH=${2}
6 | DATASET_PATH=${3}
7 |
8 | # Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism
9 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=101\
10 | metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\
11 | dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False
12 |
13 | # Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism
14 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=102\
15 | metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\
16 | dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False
17 |
18 | python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss
19 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_determinism.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H)
5 | OUTPUT_PATH=${2}
6 | DATASET_PATH=${3}
7 | DATASET_TYPE=${4}
8 |
9 | if [ "$DATASET_TYPE" == "grain" ]
10 | then
11 | EVAL_METRICS=grain_checkpoint_save_restore
12 | echo "Using grain dataset type"
13 | echo "Mounting $DATASET_PATH to /tmp/gcsfuse/"
14 | bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/
15 | DATASET_PATH=/tmp/gcsfuse/
16 | CMD_DATA=" dataset_type=grain grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*"
17 | fi
18 |
19 | #Train
20 | CMD1="python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME}_1 steps=5 metrics_file=run_1_metrics.txt\
21 | enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH"
22 | CMD1+=$CMD_DATA
23 |
24 |
25 | CMD2="python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME}_2 steps=5 metrics_file=run_2_metrics.txt\
26 | enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH"
27 | CMD2+=$CMD_DATA
28 |
29 | $CMD1
30 | $CMD2
31 | python3 end_to_end/tpu/eval_assert.py determinism metrics.txt learning/loss
32 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_dpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -xe
4 |
5 | RUN_NAME=dpo_$(date +%Y-%m-%d-%H-%M-%S)
6 |
7 | # get latest converted Gemma2 2B checkpoint from internal GCS bucket
8 | export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sort -r | head -1)
9 | LOGS="gs://maxtext-external/logs"
10 |
11 | # tfds pipeline
12 | python3 -m MaxText.train MaxText/configs/dpo.yml tokenizer_path=assets/tokenizer.gemma \
13 | run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \
14 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \
15 | per_device_batch_size=0.5 allow_split_physical_axes=True \
16 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1
17 |
18 | # grain pipeline
19 | mkdir -p /tmp/anthropic_rlhf || true
20 | gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf
21 | python3 -m MaxText.train MaxText/configs/dpo.yml tokenizer_path=assets/tokenizer.gemma \
22 | run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \
23 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \
24 | dataset_type=grain grain_worker_count=16 \
25 | grain_train_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-train.array_record*' \
26 | grain_eval_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-test.array_record*' \
27 | per_device_batch_size=0.5 allow_split_physical_axes=True \
28 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1
29 |
30 | # hf pipeline
31 | python3 -m MaxText.train MaxText/configs/dpo.yml tokenizer_path='google/gemma-2-2b-it' \
32 | run_name="$RUN_NAME-hf" model_name=gemma2-2b base_output_directory=${LOGS} \
33 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \
34 | dataset_type=hf hf_access_token=$HF_TOKEN hf_path='Anthropic/hh-rlhf' \
35 | per_device_batch_size=0.5 allow_split_physical_axes=True \
36 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1
37 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_gpt3.sh:
--------------------------------------------------------------------------------
1 | set -euox pipefail
2 |
3 | TIMESTAMP=$(date +%Y%m%d-%H%M)
4 | export PAXML_CKPT_PATH=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state
5 | export OUTPUT_PATH=gs://maxtext-gpt3/tests
6 | export RUN_NAME=test_${TIMESTAMP}
7 |
8 | # convert gpt3-52k model
9 | python3 -m MaxText.convert_gpt3_ckpt_from_paxml --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH}
10 |
11 | # Run gpt3-52k with the converted ckpt
12 | python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\
13 | steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\
14 | remat_policy=full max_target_length=2048 base_output_directory=${OUTPUT_PATH}\
15 | dataset_type=synthetic
16 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_tflops.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | USER=${1}
5 | TFLOP_THRESHOLD=${2}
6 | OUTPUT_PATH=${3}
7 | DATASET_PATH=${4}
8 |
9 |
10 | if [ -z ${5} ]
11 | then
12 | RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)
13 | else
14 | RUN_NAME=${5}_$(date +%Y-%m-%d-%H)
15 | fi
16 |
17 | #Train
18 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
19 | steps=150 reuse_example_batch=1 remat_policy='full' profiler=xplane enable_checkpointing=False metrics_file='metrics.txt'\
20 | base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH log_period=150
21 |
22 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec
23 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_tflops_16b_params.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Running test_tflops_16b_params.sh"
3 |
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce")
9 | # TFLOP_THRESHOLD (Optional, default is 0 )
10 | #
11 | # Example to invoke this script:
12 | # bash end_to_end/tpu/test_tflops_16b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0
13 |
14 | # Stop execution if any command exits with error
15 | set -ex
16 |
17 | export TFLOP_THRESHOLD=0
18 | export PLATFORM="gce"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # Set up network optimizations
27 | bash preflight.sh PLATFORM=$PLATFORM
28 |
29 | # Train
30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
32 | steps=150 per_device_batch_size=6 enable_checkpointing=false remat_policy=full\
33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=16
35 |
36 | # Assert TFLOP/s
37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec
38 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_tflops_32b_params.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Running test_tflops_32b_params.sh"
3 |
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce")
9 | # TFLOP_THRESHOLD (Optional, default is 0 )
10 | #
11 | # Example to invoke this script:
12 | # bash end_to_end/tpu/test_tflops_32b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0
13 |
14 | # Stop execution if any command exits with error
15 | set -ex
16 |
17 | export TFLOP_THRESHOLD=0
18 | export PLATFORM="gce"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # Set up network optimizations
27 | bash preflight.sh PLATFORM=$PLATFORM
28 |
29 | # Train
30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
32 | steps=150 per_device_batch_size=4 enable_checkpointing=false remat_policy=full\
33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=32
35 |
36 | # Assert TFLOP/s
37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec
38 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_tflops_64b_params.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Running test_tflops_64b_params.sh"
3 |
4 | # Command Flags:
5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml)
7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce")
9 | # TFLOP_THRESHOLD (Optional, default is 0 )
10 | #
11 | # Example to invoke this script:
12 | # bash end_to_end/tpu/test_tflops_64b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0
13 |
14 | # Stop execution if any command exits with error
15 | set -ex
16 |
17 | export TFLOP_THRESHOLD=0
18 | export PLATFORM="gce"
19 |
20 | # Set environment variables
21 | for ARGUMENT in "$@"; do
22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
23 | export "$KEY"="$VALUE"
24 | done
25 |
26 | # Set up network optimizations
27 | bash preflight.sh PLATFORM=$PLATFORM
28 |
29 | # Train
30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
32 | steps=150 per_device_batch_size=2 enable_checkpointing=false remat_policy=full\
33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=64
35 |
36 | # Assert TFLOP/s
37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec
38 |
--------------------------------------------------------------------------------
/end_to_end/tpu/test_vocab_creation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H)
5 | OUTPUT_PATH=${2}
6 | DATASET_PATH=${3}
7 | VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME
8 |
9 |
10 | #Train
11 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\
12 | base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH
13 |
14 | python3 end_to_end/tpu/eval_assert.py vocab_creation $VOCAB_PATH
15 |
--------------------------------------------------------------------------------
/getting_started/GCP_Workload_Observability.md:
--------------------------------------------------------------------------------
1 | # Enable GCP Workload Observability
2 | This guide provides an overview of how to enable GCP workload observability for your MaxText workload.
3 |
4 | ## Overview
5 | Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes.
6 | Once enabled, metrics will be automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring.
7 | If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed.
8 |
9 | The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added.
10 | Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metrics.
11 |
12 | This guide lays out how to enable the feature for your MaxText workload.
13 |
14 | ## Enabling GCP Workload Observability
15 | Users can control which metrics they want to report via the config:
16 |
17 | ### Heartbeat metric
18 | - This metric will be a boolean flag.
19 | - To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True`
20 | - To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value.
21 |
22 | ### Performance metric
23 | - This metric will be a double, capturing the training step time in seconds.
24 | - To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True`
25 |
26 | For an example, please refer to [base.yml](../MaxText/configs/base.yml).
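27 |
28 | As a minimal sketch (assuming the key=value command-line override syntax used throughout this repo; the run settings are placeholders), both metrics can be enabled directly on a training command:
29 |
30 | ```sh
31 | python3 -m MaxText.train MaxText/configs/base.yml \
32 |   run_name=observability_test \
33 |   base_output_directory=gs://<your-bucket>/logs \
34 |   dataset_type=synthetic steps=10 \
35 |   report_heartbeat_metric_for_gcp_monitoring=True \
36 |   heartbeat_reporting_interval_in_seconds=5 \
37 |   report_performance_metric_for_gcp_monitoring=True
38 | ```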
--------------------------------------------------------------------------------
/getting_started/Run_Llama2.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | ## About Llama2
18 |
19 | MaxText supports [Llama2](https://llama.meta.com/llama2) pretraining, finetuning and decoding for its 7B and 70B flavors. To get started with decoding and finetuning of Llama2, you will first need to download the weights along with the tokenizer from [Meta](https://llama.meta.com/llama-downloads).
20 |
21 | The file [test_llama2_7b.sh](https://github.com/google/maxtext/blob/main/end_to_end/tpu/llama2/7b/test_llama2_7b.sh) provides details on how to convert the PyTorch weights into Orbax checkpoint format and thereafter use them for decoding and finetuning. It also shows how to run pretraining and how to run decoding on the finetuned model checkpoint (an illustrative finetuning command appears at the end of this page).
22 |
23 | ### MaxText supports pretraining and finetuning with high performance
24 |
25 | Model Flop utilization for training on v4, v5p and v5e TPUs with MaxText.
26 |
27 |
28 | | Model | v4-128 (bf16) | v5p-128 (bf16) | v5e-256 (bf16) |
29 | | ---------- | -------------- | -------------- | -------------- |
30 | | Llama2-70b | 57% | 65% | 57% |
31 |
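32 | As an illustrative sketch of a finetuning launch (the checkpoint path is a placeholder; [test_llama2_7b.sh](https://github.com/google/maxtext/blob/main/end_to_end/tpu/llama2/7b/test_llama2_7b.sh) remains the authoritative flow):
33 |
34 | ```sh
35 | # Hypothetical example: finetune Llama2-7b from a converted Orbax checkpoint.
36 | python3 -m MaxText.train MaxText/configs/base.yml \
37 |   model_name=llama2-7b \
38 |   tokenizer_path=assets/tokenizer.llama2 \
39 |   load_parameters_path=gs://<your-bucket>/llama2-7b/0/items \
40 |   base_output_directory=gs://<your-bucket>/logs \
41 |   run_name=llama2_7b_finetune steps=10 per_device_batch_size=1
42 | ```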
--------------------------------------------------------------------------------
/maxtext_custom_wheels.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG BASEIMAGE=maxtext_base_image
2 | FROM $BASEIMAGE
3 |
4 | # Requires wheels to be in /deps. This means any custom wheels should be placed
5 | # in the maxtext directory.
6 | RUN python3 -m pip install --force-reinstall /deps/*.whl
7 |
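8 | # Illustrative build command (assumes a local base image tagged
9 | # maxtext_base_image and custom wheels placed in the maxtext directory):
10 | #   docker build -f maxtext_custom_wheels.Dockerfile -t maxtext_custom_wheels .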
--------------------------------------------------------------------------------
/maxtext_db_dependencies.Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:experimental
2 | # Copy benchmark-db
3 | FROM gcr.io/tpu-prod-env-one-vm/benchmark-db:2025-02-14
4 |
5 | # Install system dependencies
6 | RUN apt-get update && apt-get install -y curl gnupg
7 |
8 | # Add the Google Cloud SDK package repository
9 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
10 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
11 |
12 | # Install the Google Cloud SDK
13 | RUN apt-get update && apt-get install -y google-cloud-sdk
14 |
15 | # Set the default Python version to 3.10
16 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
17 |
18 | # Set environment variables for Google Cloud SDK and Python 3.10
19 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}"
20 |
21 | # Set environment variables via build arguments
22 | ARG MODE
23 | ENV ENV_MODE=$MODE
24 |
25 | ARG JAX_VERSION
26 | ENV ENV_JAX_VERSION=$JAX_VERSION
27 |
28 | ARG LIBTPU_GCS_PATH
29 | ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH
30 |
31 | ARG DEVICE
32 | ENV ENV_DEVICE=$DEVICE
33 |
34 | RUN mkdir -p /deps
35 |
36 | # Set the working directory in the container
37 | WORKDIR /deps
38 |
39 | # Copy setup files and dependency files separately for better caching
40 | COPY setup.sh ./
41 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./
42 |
43 | # Install dependencies - these steps are cached unless the copied files change
44 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}"
45 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}
46 |
47 | # Now copy the remaining code (source files that may change frequently)
48 | COPY . .
49 |
--------------------------------------------------------------------------------
/maxtext_dependencies.Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:experimental
2 | # Use Python 3.10 as the base image
3 | FROM python:3.10-slim-bullseye
4 |
5 | # Install system dependencies
6 | RUN apt-get update && apt-get install -y curl gnupg
7 |
8 | # Add the Google Cloud SDK package repository
9 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
10 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
11 |
12 | # Install the Google Cloud SDK
13 | RUN apt-get update && apt-get install -y google-cloud-sdk
14 |
15 | # Set the default Python version to 3.10
16 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
17 |
18 | # Set environment variables for Google Cloud SDK and Python 3.10
19 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}"
20 |
21 | # Set environment variables via build arguments
22 | ARG MODE
23 | ENV ENV_MODE=$MODE
24 |
25 | ARG JAX_VERSION
26 | ENV ENV_JAX_VERSION=$JAX_VERSION
27 |
28 | ARG LIBTPU_GCS_PATH
29 | ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH
30 |
31 | ARG DEVICE
32 | ENV ENV_DEVICE=$DEVICE
33 |
34 | RUN mkdir -p /deps
35 |
36 | # Set the working directory in the container
37 | WORKDIR /deps
38 |
39 | # Copy setup files and dependency files separately for better caching
40 | COPY setup.sh ./
41 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./
42 |
43 | # Install dependencies - these steps are cached unless the copied files change
44 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}"
45 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}
46 |
47 | # Now copy the remaining code (source files that may change frequently)
48 | COPY . .
49 |
--------------------------------------------------------------------------------
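An illustrative build of the dependency image above. The `--mount=type=cache` in the setup step requires BuildKit, hence `DOCKER_BUILDKIT=1`; the MODE, JAX_VERSION, and DEVICE values are examples and reach setup.sh through the ARG/ENV pairs in the Dockerfile.

```bash
# Example values only; each --build-arg maps to an ARG in the Dockerfile
# and is forwarded to setup.sh via the corresponding ENV_* variable.
DOCKER_BUILDKIT=1 docker build \
  --build-arg MODE=stable \
  --build-arg JAX_VERSION=0.4.30 \
  --build-arg DEVICE=tpu \
  -f maxtext_dependencies.Dockerfile \
  -t maxtext_base_image .
```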
/maxtext_gpu_dependencies.Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:experimental
2 | ARG BASEIMAGE=ghcr.io/nvidia/jax:base
3 | FROM $BASEIMAGE
4 |
5 | # Stopgap measure to circumvent a GPG key setup issue.
6 | RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list
7 |
8 | # Install dependencies for adjusting network rto
9 | RUN apt-get update && apt-get install -y iproute2 ethtool lsof
10 |
11 | # Install DNS util dependencies
12 | RUN apt-get install -y dnsutils
13 |
14 | # Add the Google Cloud SDK package repository
15 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
16 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
17 |
18 | # Install the Google Cloud SDK
19 | RUN apt-get update && apt-get install -y google-cloud-sdk
20 |
21 | # Set environment variables for Google Cloud SDK
22 | ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
23 |
24 | # Upgrade libcusparse to work with JAX
25 | RUN apt-get update && apt-get install -y libcusparse-12-6
26 |
27 | ARG MODE
28 | ENV ENV_MODE=$MODE
29 |
30 | ARG JAX_VERSION
31 | ENV ENV_JAX_VERSION=$JAX_VERSION
32 |
33 | ARG DEVICE
34 | ENV ENV_DEVICE=$DEVICE
35 |
36 | RUN mkdir -p /deps
37 |
38 | # Set the working directory in the container
39 | WORKDIR /deps
40 |
41 | # Copy setup files and dependency files separately for better caching
42 | COPY setup.sh ./
43 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./
44 |
45 | # Install dependencies - these steps are cached unless the copied files change
46 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
47 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
48 |
49 | # Now copy the remaining code (source files that may change frequently)
50 | COPY . .
51 |
--------------------------------------------------------------------------------
/maxtext_jax_ai_image.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG JAX_AI_IMAGE_BASEIMAGE
2 |
3 | # JAX AI Base Image
4 | FROM $JAX_AI_IMAGE_BASEIMAGE
5 | ARG JAX_AI_IMAGE_BASEIMAGE
6 |
7 | ARG COMMIT_HASH
8 |
9 | ENV COMMIT_HASH=$COMMIT_HASH
10 |
11 | RUN mkdir -p /deps
12 |
13 | # Set the working directory in the container
14 | WORKDIR /deps
15 |
16 | # Copy setup files and dependency files separately for better caching
17 | COPY setup.sh ./
18 | COPY requirements.txt requirements_with_jax_ai_image.txt ./
19 |
20 |
21 | # For JAX AI tpu training images 0.4.37 AND 0.4.35
22 | # Orbax checkpoint installs the latest version of JAX,
23 | # but the libtpu version in the base image is older.
24 | # This version mismatch can cause compatibility issues
25 | # and break MaxText.
26 | # Upgrade libtpu version if using either of the old stable images
27 |
28 | ARG DEVICE
29 | ENV DEVICE=$DEVICE
30 |
31 | RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1" ] || [ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1" ]); then \
32 | python3 -m pip install --no-cache-dir --upgrade jax[tpu]; fi
33 |
34 | # Install Maxtext requirements with Jax AI Image
35 | RUN apt-get update && apt-get install --yes dnsutils
36 | # TODO(bvandermoon, parambole): Remove this when it's added to JAX AI Image
37 | RUN pip install google-cloud-monitoring
38 | RUN python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt
39 |
40 | # Now copy the remaining code (source files that may change frequently)
41 | COPY . .
42 | RUN ls .
43 |
44 | # Run the script available in JAX AI base image to generate the manifest file
45 | RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
46 |
--------------------------------------------------------------------------------
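An illustrative build of the JAX AI image above. The base-image tag is one of the two stable tags the Dockerfile special-cases for the libtpu upgrade; COMMIT_HASH is normally injected by the CI workflows, and here is taken from the local checkout for the sketch.

```bash
# Example invocation; the base-image tag matches a tag named in the
# Dockerfile's libtpu-upgrade guard, and COMMIT_HASH comes from git.
docker build \
  --build-arg JAX_AI_IMAGE_BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 \
  --build-arg DEVICE=tpu \
  --build-arg COMMIT_HASH="$(git rev-parse HEAD)" \
  -f maxtext_jax_ai_image.Dockerfile \
  -t maxtext_jax_ai_image .
```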
/maxtext_libtpu_path.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG BASEIMAGE=maxtext_base_image
2 | FROM $BASEIMAGE
3 |
4 | #FROM maxtext_base_image
5 | # Set the TPU_LIBRARY_PATH
6 | ENV TPU_LIBRARY_PATH='/root/custom_libtpu/libtpu.so'
7 |
8 | WORKDIR /deps
--------------------------------------------------------------------------------
/maxtext_runner.Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker.io/docker/dockerfile:1.7-labs
2 |
3 | ARG BASEIMAGE=maxtext_base_image
4 | FROM $BASEIMAGE
5 |
6 | #FROM maxtext_base_image
7 |
8 | # Set the working directory in the container
9 | WORKDIR /deps
10 |
11 | # Copy assets separately
12 | COPY assets assets/
13 | COPY MaxText/test_assets/ MaxText/test_assets/
14 |
15 | # Copy all files except assets from local workspace into docker container
16 | COPY --exclude=assets --exclude=MaxText/test_assets . .
17 |
--------------------------------------------------------------------------------
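One detail worth noting in the runner Dockerfile above: `COPY --exclude` is only understood by the dockerfile:1.7-labs frontend declared in the `# syntax` line, which in turn requires BuildKit. An illustrative build (the image tag is an example):

```bash
# BuildKit is required for the labs frontend that provides COPY --exclude.
DOCKER_BUILDKIT=1 docker build \
  --build-arg BASEIMAGE=maxtext_base_image \
  -f maxtext_runner.Dockerfile \
  -t maxtext_runner .
```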
/pedagogical_examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
--------------------------------------------------------------------------------
/pedagogical_examples/non_spmd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | """
4 | Copyright 2023 Google LLC
5 |
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 |
10 | https://www.apache.org/licenses/LICENSE-2.0
11 |
12 | Unless required by applicable law or agreed to in writing, software
13 | distributed under the License is distributed on an "AS IS" BASIS,
14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | See the License for the specific language governing permissions and
16 | limitations under the License.
17 | """
18 |
19 | """
20 | This program demonstrates embarrassingly parallelizable non-SPMD computations in Jax, in this case by having each
21 | process_index run its own computation.
22 | The same approach can be extended for non-embarrassingly parallelizable computations.
23 | The simplest way to do that would be by running embarrassingly parallelizable computations on arbitrary submeshes,
24 | then using a `host_local_array_to_global_array` to reshard into a new global array.
25 | An important limitation of this approach is that we cannot overlap communication and computation between the different
26 | kernel calls.
27 | """
28 |
29 | import numpy as np
30 |
31 | import jax
32 | from jax.sharding import PartitionSpec
33 | from jax.sharding import Mesh
34 |
35 |
36 | # Notice this is jax.local_devices(), not jax.devices(). Hence each process (on TPUVMs, each VM) will run separate programs
37 | # on its mesh.
38 | mesh = Mesh(np.array(jax.local_devices()), ["data"])
39 | sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
40 | idx = jax.process_index()
41 |
42 |
43 | # Example step depends on idx which is different on each program
44 | def example_step():
45 | return idx * jax.numpy.ones((idx + 1))
46 |
47 |
48 | jit_func = jax.jit(
49 | example_step,
50 | out_shardings=sharding,
51 | )
52 |
53 | # pylint: disable=not-callable
54 | print(f"{idx=} -> {jit_func()=}")
55 |
--------------------------------------------------------------------------------
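Because non_spmd.py builds its mesh from jax.local_devices(), it is meant to be launched once per host; on a Cloud TPU VM slice that looks roughly like the following (the TPU name, zone, and working directory are placeholders):

```bash
# Launch the per-process program on every worker of a TPU VM slice; each
# worker has a different jax.process_index(), so each runs its own computation.
TPU_NAME=my-tpu      # placeholder
ZONE=us-central2-b   # placeholder
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="$ZONE" --worker=all \
  --command="cd maxtext && python3 pedagogical_examples/non_spmd.py"
```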
/preflight.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Running preflight.sh"
3 | # Command Flags:
4 | #
5 | # Example to invoke this script:
6 | # bash preflight.sh
7 |
8 | # Warning:
9 | # For any dependencies, please add them into `setup.sh` or `maxtext_dependencies.Dockerfile`.
10 | # You should not install any dependencies in this file.
11 |
12 | # Stop execution if any command exits with error
13 | set -e
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 | export "$KEY"="$VALUE"
19 | done
20 |
21 | # Check if sudo is available
22 | if command -v sudo >/dev/null 2>&1; then
23 | # sudo is available, use it
24 | echo "running rto_setup.sh with sudo"
25 |
26 | # apply network settings.
27 | sudo bash rto_setup.sh
28 | else
29 | # sudo is not available, run the script without sudo
30 | echo "running rto_setup.sh without sudo"
31 |
32 | # apply network settings.
33 | bash rto_setup.sh
34 | fi
--------------------------------------------------------------------------------
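Since preflight.sh exports any KEY=VALUE arguments as environment variables before applying the network settings in rto_setup.sh, invocation is simply (the key below is purely illustrative):

```bash
# Plain invocation, or with KEY=VALUE pairs that are exported first.
bash preflight.sh
bash preflight.sh EXAMPLE_KEY=example_value  # hypothetical key
```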
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | [build-system]
16 | requires = ["setuptools>=61.0"]
17 | build-backend = "setuptools.build_meta"
18 |
19 | [project]
20 | name = "MaxText"
21 | version = "0.1.0"
22 | description = "MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and targeting Google Cloud TPUs and GPUs for training and inference"
23 | readme = "README.md"
24 | requires-python = ">=3.10"
25 | license = {file = "LICENSE"}
26 | dynamic = ["dependencies"]
27 |
28 | [tool.setuptools.dynamic]
29 | dependencies = {file = ["requirements.txt"]}
30 |
31 | [tool.setuptools.packages.find]
32 | where = ["MaxText"]
33 | include = ["MaxText*"]
34 | exclude = ["MaxText.tests.*"]
35 |
--------------------------------------------------------------------------------
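Given that the project's dependencies are resolved dynamically from requirements.txt (per the `[tool.setuptools.dynamic]` table above), a local editable install is the standard entry point; this is plain pip/setuptools usage, nothing MaxText-specific:

```bash
# Editable install; pip reads requirements.txt through the dynamic table.
python3 -m pip install -e .
```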
/pytest.ini:
--------------------------------------------------------------------------------
1 | # pytest.ini
2 | [pytest]
3 | testpaths =
4 | tests
5 | python_files = *_test.py *_tests.py
6 | addopts =
7 | -rf --import-mode=importlib --strict-markers
8 | --ignore=MaxText/tests/profiler_test.py
9 | --ignore=MaxText/tests/train_smoke_test.py
10 | --ignore=MaxText/tests/train_int8_smoke_test.py
11 | --ignore=MaxText/tests/train_gpu_smoke_test.py
12 | --ignore=MaxText/tests/train_using_ragged_dot_smoke_test.py
13 | markers =
14 | tpu_only: marks tests to be run on TPUs only
15 | gpu_only: marks tests to be run on GPUs only
16 | cpu_only: marks tests to be run on CPUs only
17 | integration_test: tests exercising larger portions of the system,
18 | including interactions with other systems like GCS,
19 | e.g., end_to_end tests
20 |
--------------------------------------------------------------------------------
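The markers registered above (and enforced by `--strict-markers`) gate hardware-specific tests, so standard `pytest -m` selection applies:

```bash
# Run only CPU-safe tests, or everything except the slower integration tests.
python3 -m pytest -m cpu_only
python3 -m pytest -m "not integration_test"
```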
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax>=0.4.30
2 | jaxlib>=0.4.30
3 | orbax-checkpoint>=0.5.12
4 | absl-py
5 | array-record
6 | aqtp
7 | cloud-accelerator-diagnostics
8 | cloud-tpu-diagnostics
9 | datasets
10 | gcsfs
11 | google-cloud-aiplatform==1.61.0
12 | google-cloud-storage
13 | google-cloud-monitoring
14 | google-api-core
15 | google-api-python-client
16 | grain[parquet]>=0.2.6
17 | huggingface_hub
18 | flax>=0.10.6
19 | jaxtyping
20 | ml-collections
21 | ml-goodput-measurement==0.0.10
22 | numpy
23 | optax
24 | protobuf==3.20.3
25 | pylint
26 | pytest
27 | pyink
28 | pre-commit
29 | pytype
30 | pillow>=11.1.0
31 | sentencepiece==0.2.0
32 | tensorflow-text>=2.13.0
33 | tensorflow>=2.13.0
34 | tensorflow-datasets
35 | tensorboardx>=2.6.2.2
36 | tensorboard-plugin-profile
37 | tiktoken
38 | transformers
39 | mlperf-logging@git+https://github.com/mlperf/logging.git
40 | google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
41 | jsonlines
42 | pathwaysutils==0.1.1
43 | omegaconf
44 |
--------------------------------------------------------------------------------
/requirements_with_jax_ai_image.txt:
--------------------------------------------------------------------------------
1 | # Requirements for Building the MaxText Docker Image
2 | # These requirements are additional to the dependencies present in the JAX AI base image.
3 | datasets
4 | grain[parquet]>=0.2.6
5 | orbax-checkpoint>=0.10.3
6 | pylint
7 | pytest
8 | pyink
9 | pre-commit
10 | protobuf==3.20.3
11 | pytype
12 | pillow>=11.1.0
13 | sentencepiece==0.1.97
14 | tensorflow-text>=2.13.0
15 | tensorflow-datasets
16 | tiktoken
17 | transformers
18 | mlperf-logging@git+https://github.com/mlperf/logging.git
19 | google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
20 | jsonlines
21 | pathwaysutils==0.1.1
22 | google-api-python-client
23 | omegaconf
24 | jaxtyping
25 |
--------------------------------------------------------------------------------
/rto_setup.sh:
--------------------------------------------------------------------------------
1 | echo "Running rto_setup.sh"
2 |
3 | # Stop execution if any command exits with error
4 | set -e
5 |
6 | echo "Adjust Network settings and apply non cache copy"
7 |
8 | # Disable slow start after idle
9 | sysctl net.ipv4.tcp_slow_start_after_idle=0
10 |
11 | # Disable metrics cache
12 | sysctl net.ipv4.tcp_no_metrics_save=1
13 |
14 | # Address rto_min issue with two default routing entries: screen/7RGgkiXkGXSeYF2
15 | route=$(ip route show | sed -n 1p)
16 | second_route=$(ip route show | sed -n 2p)
17 | if [[ "${second_route}" =~ ^default.* ]]; then
18 | modified_route=${route//" lock"/}
19 | ip route delete ${modified_route}
20 | fi
21 | route=$(ip route show | sed -n 1p)
22 | echo "route rto before change: $route"
23 | if [[ "${route}" =~ .*lock.*5ms.* ]]; then
24 | echo "${route}"
25 | else
26 | # shellcheck disable=SC2086
27 | ip route change $route rto_min 5ms
28 | fi
29 | route=$(ip route show | sed -n 1p)
30 | echo "route rto after change: $route"
31 |
32 | # Disable Cubic Hystart Ack-Train
33 | echo 2 > /sys/module/tcp_cubic/parameters/hystart_detect
34 |
35 | # Improve handling SYN burst
36 | echo 4096 > /proc/sys/net/core/somaxconn
37 | echo 4096 > /proc/sys/net/ipv4/tcp_max_syn_backlog
38 |
39 | # Disable MTU Discovery
40 | echo 0 > /proc/sys/net/ipv4/tcp_mtu_probing
41 |
42 | # Increase TCP Zerocopy control memory
43 | sysctl -w net.core.optmem_max=131072
44 |
45 | # Printing output of `ip route show`
46 | echo -e "\nPrinting output of 'ip route show':"
47 | ip route show
48 |
49 | first_line_res=$(ip route show | head -n 1)
50 | dev_name=$(echo "$first_line_res" | awk -F'[[:space:]]' '{ print $5 }')
51 | echo "dev_name=${dev_name}"
52 | ethtool -K "${dev_name}" tx-nocache-copy on
53 |
54 | echo "rto_setup.sh finished"
--------------------------------------------------------------------------------
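A quick way to confirm the settings applied by rto_setup.sh took effect, assuming the script already ran (read-only checks only):

```bash
# Verify the sysctl tunables and the rto_min on the default route.
sysctl net.ipv4.tcp_slow_start_after_idle net.ipv4.tcp_no_metrics_save
ip route show | head -n 1   # the script expects this to match "lock ... 5ms"
```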
/setup_with_retries.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | max_attempts=5
4 | current_attempt=1
5 |
6 | # Loop until setup succeeds or reaches the maximum number of attempts
7 | while [ $current_attempt -le $max_attempts ]; do
8 | echo "Attempt to run setup.sh $current_attempt:"
9 | bash setup.sh "$@"
10 |
11 | # Check the exit status of setup.sh
12 | if [ $? -eq 0 ]; then
13 | echo "Success for running setup on attempt $current_attempt!"
14 | exit 0
15 | else
16 | echo "Failed to run setup on attempt $current_attempt."
17 | ((current_attempt++))
18 | sleep 5 # Short delay before next attempt
19 | fi
20 | done
21 |
22 | echo "All attempts to run setup failed. Exiting."
23 | exit 1
--------------------------------------------------------------------------------
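Arguments are forwarded verbatim to setup.sh on each attempt, so the same KEY=VALUE flags used in the Dockerfiles above apply; the values here are examples:

```bash
# Retry setup.sh up to five times with the given flags.
bash setup_with_retries.sh MODE=stable DEVICE=tpu
```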
/unit_test_and_lint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2023 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | python3 -m pylint $(git ls-files '*.py')
18 |
19 | python3 -m pytest --pyargs MaxText.tests
20 |
--------------------------------------------------------------------------------