├── .dockerignore
├── .github
│   ├── CODEOWNERS
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── AddLabel.yml
│       ├── CPUTests.yml
│       ├── RunTests.yml
│       ├── UploadDockerImages.yml
│       ├── build_and_upload_images.sh
│       ├── build_upload_internal.yml
│       ├── require-checklist.yml
│       ├── run_tests_internal.yml
│       └── utils
│           └── setup_runner.sh
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   ├── launch.json
│   └── settings.json
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── MaxText
│   ├── __init__.py
│   ├── accelerator_to_spec_map.py
│   ├── benchmark_chunked_prefill.py
│   ├── checkpointing.py
│   ├── common_types.py
│   ├── configs
│   │   ├── README.md
│   │   ├── a3
│   │   │   ├── llama_2_7b
│   │   │   │   ├── 16vm.sh
│   │   │   │   ├── 1vm.sh
│   │   │   │   ├── 2vm.sh
│   │   │   │   ├── 4vm.sh
│   │   │   │   ├── 8vm.sh
│   │   │   │   └── README.md
│   │   │   └── llama_3.1_405b
│   │   │       └── 128vm.sh
│   │   ├── base.yml
│   │   ├── dpo.yml
│   │   ├── experimental
│   │   │   ├── 1024b.sh
│   │   │   ├── 128b.sh
│   │   │   ├── 256b.sh
│   │   │   ├── 32b.sh
│   │   │   ├── 512b.sh
│   │   │   └── 64b.sh
│   │   ├── gpu_smoke_test.yml
│   │   ├── inference.yml
│   │   ├── inference_jetstream.yml
│   │   ├── models
│   │   │   ├── deepseek2-16b.yml
│   │   │   ├── deepseek2-236b.yml
│   │   │   ├── deepseek3-671b.yml
│   │   │   ├── gemma-2b.yml
│   │   │   ├── gemma-7b.yml
│   │   │   ├── gemma2-27b.yml
│   │   │   ├── gemma2-2b.yml
│   │   │   ├── gemma2-9b.yml
│   │   │   ├── gemma3-12b.yml
│   │   │   ├── gemma3-27b.yml
│   │   │   ├── gemma3-4b.yml
│   │   │   ├── gpt3-175b.yml
│   │   │   ├── gpt3-22b.yml
│   │   │   ├── gpt3-52k.yml
│   │   │   ├── gpt3-6b.yml
│   │   │   ├── gpu
│   │   │   │   ├── llama2_70b.yml
│   │   │   │   ├── llama2_7b.yml
│   │   │   │   ├── llama3.1_405b.yml
│   │   │   │   ├── llama3_70b.yml
│   │   │   │   ├── llama3_8b.yml
│   │   │   │   ├── mixtral_8x1b.yml
│   │   │   │   ├── mixtral_8x2b.yml
│   │   │   │   └── mixtral_8x7b.yml
│   │   │   ├── llama2-13b.yml
│   │   │   ├── llama2-70b.yml
│   │   │   ├── llama2-7b.yml
│   │   │   ├── llama3-405b.yml
│   │   │   ├── llama3-70b.yml
│   │   │   ├── llama3-8b.yml
│   │   │   ├── llama3.1-405b.yml
│   │   │   ├── llama3.1-70b.yml
│   │   │   ├── llama3.1-8b.yml
│   │   │   ├── llama3.3-70b.yml
│   │   │   ├── llama4-17b-128e.yml
│   │   │   ├── llama4-17b-16e.yml
│   │   │   ├── mistral-7b.yml
│   │   │   ├── mixtral-8x22b.yml
│   │   │   └── mixtral-8x7b.yml
│   │   ├── quantization
│   │   │   ├── README.md
│   │   │   ├── dense_llm_subchannel.json
│   │   │   ├── dense_llm_weight_only_scale.json
│   │   │   ├── int4_weight_only.json
│   │   │   └── int8_weight_only.json
│   │   ├── sft.yml
│   │   ├── tpu_smoke_test.yml
│   │   ├── trillium
│   │   │   ├── gemma2_27b.sh
│   │   │   ├── gemma2_9b.sh
│   │   │   ├── gemma3_27b.sh
│   │   │   ├── gpt3_175b.sh
│   │   │   ├── llama2_7b_4096.sh
│   │   │   └── mixtral_8x7b.sh
│   │   ├── v4
│   │   │   ├── 22b.sh
│   │   │   ├── 52b.sh
│   │   │   └── README.md
│   │   ├── v5e
│   │   │   ├── 128b.sh
│   │   │   ├── 16b.sh
│   │   │   ├── 32b.sh
│   │   │   ├── 64b.sh
│   │   │   ├── README.md
│   │   │   ├── gpt3_175b.sh
│   │   │   ├── llama2_13b.sh
│   │   │   ├── llama2_70b.sh
│   │   │   ├── llama2_70b_v5e-16.yml
│   │   │   ├── llama2_7b.sh
│   │   │   ├── llama3_405b_v5e-64.yml
│   │   │   └── llama3_70b_v5e-16.yml
│   │   ├── v5p
│   │   │   ├── 1024b.sh
│   │   │   ├── 128b.sh
│   │   │   ├── 256b.sh
│   │   │   ├── 32b.sh
│   │   │   ├── 512b.sh
│   │   │   ├── 64b.sh
│   │   │   ├── README.md
│   │   │   ├── gpt3_175b
│   │   │   │   ├── gpt3_175b_base.sh
│   │   │   │   ├── v5p_1024.sh
│   │   │   │   ├── v5p_12288.sh
│   │   │   │   ├── v5p_2048.sh
│   │   │   │   ├── v5p_3072.sh
│   │   │   │   ├── v5p_4096.sh
│   │   │   │   └── v5p_8192.sh
│   │   │   ├── llama2_70b.sh
│   │   │   └── llama2_7b.sh
│   │   └── v6e
│   │       └── inference
│   │           └── llama4_maverick_v6e-64.yml
│   ├── convert_deepseek_ckpt.py
│   ├── convert_deepseek_unscanned_ckpt.py
│   ├── convert_gemma2_chkpt.py
│   ├── convert_gemma3_chkpt.py
│   ├── convert_gemma_chkpt.py
│   ├── convert_gpt3_ckpt_from_paxml.py
│   ├── decode.py
│   ├── deepseek_fp8_to_bf16.py
│   ├── elastic_train.py
│   ├── experimental
│   │   └── rl
│   │       ├── grpo.yml
│   │       ├── grpo_input_pipeline.py
│   │       ├── grpo_trainer.py
│   │       └── grpo_trainer_test.yml
│   ├── gcp_workload_monitor.py
│   ├── generate_distillation_data.py
│   ├── generate_param_only_checkpoint.py
│   ├── globals.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   └── multi_host
│   │   │       ├── disaggregation
│   │   │       │   └── llama3_405b_v6e-16-16.yml
│   │   │       └── interleaved
│   │   │           ├── llama2_70b_v5e-16.yml
│   │   │           ├── llama3_405b_v5e-64.yml
│   │   │           └── llama3_70b_v5e-16.yml
│   │   ├── decode_multi.py
│   │   ├── gpu
│   │   │   ├── README.md
│   │   │   └── microbenchmark_llama2-70b_h100-8.sh
│   │   ├── jetstream_pathways
│   │   │   ├── Dockerfile
│   │   │   ├── README.md
│   │   │   └── jetstream_pathways_entrypoint.sh
│   │   ├── kvcache.py
│   │   ├── maxengine_server
│   │   │   ├── Dockerfile
│   │   │   ├── README.md
│   │   │   └── maxengine_server_entrypoint.sh
│   │   ├── page_manager.py
│   │   ├── paged_attention.py
│   │   └── paged_attention_kernel_v2.py
│   ├── inference_microbenchmark.py
│   ├── inference_microbenchmark_sweep.py
│   ├── inference_mlperf
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── evaluate-accuracy-fast.py
│   │   ├── evaluate-accuracy.py
│   │   ├── gpu
│   │   │   └── benchmarks_llama2-70b-h100_8.sh
│   │   ├── llama_offline_run.sh
│   │   ├── matmul
│   │   │   ├── __init__.py
│   │   │   ├── matmul_dtypes.py
│   │   │   ├── matmul_sharding.py
│   │   │   └── timing_util.py
│   │   ├── mixtral_offline_run.sh
│   │   ├── offline_inference.py
│   │   ├── offline_mode.py
│   │   ├── requirements.txt
│   │   ├── trillium
│   │   │   ├── __init__.py
│   │   │   ├── benchmarks_llama2-70b-trillium_2x4.sh
│   │   │   ├── microbenchmarks_llama2-70b-trillium_2x4.sh
│   │   │   └── select_xla_flags.py
│   │   ├── user.conf
│   │   ├── user100.conf
│   │   └── user5000.conf
│   ├── inference_utils.py
│   ├── input_pipeline
│   │   ├── __init__.py
│   │   ├── _distillation_data_processing.py
│   │   ├── _grain_data_processing.py
│   │   ├── _grain_tokenizer.py
│   │   ├── _hf_data_processing.py
│   │   ├── _input_pipeline_utils.py
│   │   ├── _tfds_data_processing.py
│   │   ├── _tfds_data_processing_c4_mlperf.py
│   │   └── input_pipeline_interface.py
│   ├── kernels
│   │   ├── __init__.py
│   │   ├── megablox
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── gmm.py
│   │   │   └── ops.py
│   │   └── ragged_attention.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attentions.py
│   │   ├── deepseek.py
│   │   ├── embeddings.py
│   │   ├── gemma.py
│   │   ├── gemma2.py
│   │   ├── gemma3.py
│   │   ├── gpt3.py
│   │   ├── initializers.py
│   │   ├── linears.py
│   │   ├── llama2.py
│   │   ├── llama4.py
│   │   ├── mistral.py
│   │   ├── mixtral.py
│   │   ├── models.py
│   │   ├── moe.py
│   │   ├── multi_token_prediction.py
│   │   ├── normalizations.py
│   │   ├── pipeline.py
│   │   ├── quantizations.py
│   │   └── simple_layer.py
│   ├── llama4_ckpt_unscanned.py
│   ├── llama_ckpt_conversion_inference_only.py
│   ├── llama_mistral_mixtral_orbax_to_hf.py
│   ├── llama_or_mistral_ckpt.py
│   ├── load_and_quantize_checkpoint.py
│   ├── max_logging.py
│   ├── max_utils.py
│   ├── maxengine.py
│   ├── maxengine_config.py
│   ├── maxengine_server.py
│   ├── maxtext_utils.py
│   ├── metric_logger.py
│   ├── multihost_dataloading.py
│   ├── multimodal_utils.py
│   ├── optimizers.py
│   ├── prefill_packing.py
│   ├── profiler.py
│   ├── pyconfig.py
│   ├── scratch_code
│   │   ├── __init__.py
│   │   ├── analyze_sharegpt.py
│   │   ├── gemma_7b.sh
│   │   ├── generate_grpo_golden_logits.py
│   │   ├── generate_hf_golden_logits.py
│   │   ├── generate_sft_golden_data.py
│   │   ├── golden_gemma-2b_export.ipynb
│   │   ├── golden_gemma2-27b_export-flax.ipynb
│   │   ├── golden_gemma2-2b_export-flax.ipynb
│   │   ├── golden_gemma2-2b_export.ipynb
│   │   ├── golden_gemma2-9b_export-flax.ipynb
│   │   ├── golden_gemma2-9b_export.ipynb
│   │   ├── golden_llama2-70b_export.py
│   │   ├── golden_llama2-7b_export.ipynb
│   │   ├── golden_llama3-8b_export.ipynb
│   │   ├── golden_llama3_1_export.py
│   │   ├── golden_llama4_17b_16e_128e_export.ipynb
│   │   ├── golden_mistral-7b_export.ipynb
│   │   ├── golden_mixtral-8x22b_export.ipynb
│   │   ├── golden_mixtral-8x7b_export.ipynb
│   │   ├── mixtral-numerical-verification.ipynb
│   │   ├── run_inference_microbenchmark.sh
│   │   └── setup_transformer.sh
│   ├── sequence_packing.py
│   ├── sft_trainer.py
│   ├── standalone_checkpointer.py
│   ├── standalone_dataloader.py
│   ├── test_assets
│   │   ├── .gitignore
│   │   ├── golden_data_deepseek_r1_distill_llama3.1-70b.jsonl
│   │   ├── golden_data_deepseek_r1_distill_llama3.1_8b.jsonl
│   │   ├── golden_data_gemma3_vit.jsonl
│   │   ├── golden_data_grpo_default.jsonl
│   │   ├── golden_data_sft_default.jsonl
│   │   └── test_image.jpg
│   ├── tests
│   │   ├── __init__.py
│   │   ├── aot_hlo_identical_script.sh
│   │   ├── aot_hlo_identical_test.py
│   │   ├── attention_test.py
│   │   ├── check_llama4_layers.py
│   │   ├── check_mla_vs_reference.py
│   │   ├── context_parallelism_test.py
│   │   ├── decode_tests.py
│   │   ├── distillation_data_processing_test.py
│   │   ├── elastic_train_test.py
│   │   ├── forward_pass_logit_checker.py
│   │   ├── gpt3_test.py
│   │   ├── grain_data_processing_test.py
│   │   ├── grpo_trainer_correctness_test.py
│   │   ├── hf_checkpoint_conversion_checker.py
│   │   ├── hf_checkpoint_conversion_test.py
│   │   ├── hf_data_processing_test.py
│   │   ├── inference
│   │   │   ├── __init__.py
│   │   │   ├── kvcache_test.py
│   │   │   ├── page_manager_test.py
│   │   │   ├── test_llama2_7b_bf16.sh
│   │   │   └── test_llama2_7b_int8.sh
│   │   ├── integration_tests
│   │   │   ├── __init__.py
│   │   │   ├── checkpoint_compatibility_test.py
│   │   │   ├── checkpointing_test.py
│   │   │   ├── generate_param_only_checkpoint_test.py
│   │   │   ├── gradient_accumulation_test.py
│   │   │   ├── grpo_correctness.py
│   │   │   ├── inference_microbenchmark_smoke_test.py
│   │   │   ├── sft_trainer_correctness_test.py
│   │   │   ├── shmap_collective_matmul_test.py
│   │   │   ├── standalone_dl_ckpt_test.py
│   │   │   ├── train_tests.py
│   │   │   └── vision_encoder_test.py
│   │   ├── kernels_test.py
│   │   ├── llama_test.py
│   │   ├── max_utils_test.py
│   │   ├── maxengine_test.py
│   │   ├── maxtext_utils_test.py
│   │   ├── model_test.py
│   │   ├── moe_test.py
│   │   ├── multi_token_prediction_test.py
│   │   ├── multihost_dataloading_test.py
│   │   ├── multimodal_utils_test.py
│   │   ├── pipeline_parallelism_test.py
│   │   ├── profiler_test.py
│   │   ├── pyconfig_test.py
│   │   ├── quantizations_test.py
│   │   ├── sft_data_processing_test.py
│   │   ├── simple_decoder_layer_test.py
│   │   ├── state_dtypes_test.py
│   │   ├── tfds_data_processing_test.py
│   │   ├── tokenizer_test.py
│   │   ├── train_compile_test.py
│   │   ├── train_gpu_smoke_test.py
│   │   ├── train_int8_smoke_test.py
│   │   ├── train_smoke_test.py
│   │   └── train_using_ragged_dot_smoke_test.py
│   ├── tokenizer.py
│   ├── train.py
│   ├── train_compile.py
│   ├── train_tokenizer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── gcs_utils.py
│   │   └── lora_utils.py
│   ├── vertex_tensorboard.py
│   └── weight_inspector.py
├── PREFLIGHT.md
├── README.md
├── assets
│   ├── tokenizer
│   ├── tokenizer.gemma
│   ├── tokenizer.gemma3
│   ├── tokenizer.llama2
│   ├── tokenizer.mistral-v1
│   ├── tokenizer.mistral-v3
│   └── tokenizer_llama3.tiktoken
├── benchmarks
│   ├── Getting_Started_Benchmarking.md
│   ├── __init__.py
│   ├── benchmark_db_utils.py
│   ├── benchmark_runner.py
│   ├── command_utils.py
│   ├── disruption_management
│   │   ├── __init__.py
│   │   ├── disruption_handler.py
│   │   ├── disruption_manager.py
│   │   ├── disruption_utils.py
│   │   └── monitor.py
│   ├── llama2_v6e-256_benchmarks.py
│   ├── maxtext_trillium_model_configs.py
│   ├── maxtext_v5e_model_configs.py
│   ├── maxtext_v5p_model_configs.py
│   ├── maxtext_xpk_runner.py
│   ├── mmlu
│   │   ├── __init__.py
│   │   ├── mmlu_categories.py
│   │   └── mmlu_eval.py
│   ├── recipes
│   │   ├── __init__.py
│   │   ├── args_helper.py
│   │   ├── pw_elastic_training_recipe.py
│   │   ├── pw_long_running_recipe.py
│   │   ├── pw_mcjax_benchmark_recipe.py
│   │   ├── pw_mcjax_checkpoint_benchmark_recipe.py
│   │   ├── pw_remote_python_recipe.py
│   │   └── pw_suspend_resume.py
│   ├── upload_metrics_to_bq.py
│   ├── xla_flags_library.py
│   └── xpk_configs.py
├── code_style.sh
├── constraints_gpu.txt
├── docker_build_dependency_image.sh
├── docker_upload_runner.sh
├── download_dataset.sh
├── end_to_end
│   ├── gpu
│   │   ├── a3
│   │   │   ├── test_convergence_125m_params.sh
│   │   │   ├── test_convergence_1b_params.sh
│   │   │   ├── test_gemma3_logits.sh
│   │   │   └── test_llama2_7b.sh
│   │   ├── mixtral
│   │   │   └── test_8x7b.sh
│   │   ├── test_collective_matmul_llama2_7b.sh
│   │   ├── test_feature.py
│   │   └── test_fp8_gemm_llama2_7b.sh
│   ├── test_checkpoint_compatibility.sh
│   ├── test_checkpointing.sh
│   ├── test_generate_param_only_checkpoint.sh
│   ├── test_jdi.sh
│   ├── test_mtc_phase_2_save_path.sh
│   ├── test_multi_tier_checkpointing.sh
│   ├── test_profiler.py
│   └── tpu
│       ├── deepseek
│       │   ├── Run_DeepSeek.md
│       │   ├── v2-16b
│       │   │   └── test_deepseek.sh
│       │   └── v3-671b
│       │       └── test_deepseek.sh
│       ├── eval_assert.py
│       ├── gemma
│       │   ├── 2b
│       │   │   └── test_gemma.sh
│       │   ├── 7b
│       │   │   ├── 1_test_gemma.sh
│       │   │   └── 2_test_gemma.sh
│       │   └── Run_Gemma.md
│       ├── gemma2
│       │   ├── 27b
│       │   │   ├── 1_test_gemma.sh
│       │   │   └── 2_test_gemma.sh
│       │   ├── 2b
│       │   │   └── test_gemma2.sh
│       │   └── 9b
│       │       ├── 1_test_gemma.sh
│       │       └── 2_test_gemma.sh
│       ├── gemma3
│       │   ├── 12b
│       │   │   └── test_gemma3.sh
│       │   ├── 27b
│       │   │   └── test_gemma3.sh
│       │   ├── 4b
│       │   │   └── test_gemma3.sh
│       │   └── Run_Gemma3.md
│       ├── llama2
│       │   ├── 13b
│       │   │   ├── 1_test_llama2_13b.sh
│       │   │   └── 2_test_llama2_13b.sh
│       │   ├── 70b
│       │   │   ├── 1_test_llama2_70b.sh
│       │   │   └── 2_test_llama2_70b.sh
│       │   └── 7b
│       │       └── test_llama2_7b.sh
│       ├── llama3.1
│       │   ├── 405b
│       │   │   ├── 2_test_llama3.1_405b.sh
│       │   │   └── 3_test_llama3.1_405b.sh
│       │   ├── 70b
│       │   │   ├── 1_test_llama3.1_70b.sh
│       │   │   ├── 2_test_llama3.1_70b.sh
│       │   │   └── 3_test_llama3.1_70b.sh
│       │   └── 8b
│       │       ├── 1_test_llama3.1_8b.sh
│       │       ├── 2_test_llama3.1_8b.sh
│       │       └── 3_test_llama3.1_8b.sh
│       ├── llama3.3
│       │   └── 70b
│       │       ├── 1_test_llama3.3_70b.sh
│       │       └── 2_test_llama3.3_70b.sh
│       ├── llama3
│       │   ├── 70b
│       │   │   ├── 1_test_llama3_70b.sh
│       │   │   └── 2_test_llama3_70b.sh
│       │   └── 8b
│       │       ├── 1_test_llama3_8b.sh
│       │       └── 2_test_llama3_8b.sh
│       ├── llama4
│       │   ├── 1_test_llama4.sh
│       │   ├── 2_test_llama4.sh
│       │   └── Run_Llama4.md
│       ├── llama_finetuning_test.sh
│       ├── mistral
│       │   └── 7b
│       │       └── test_mistral-7b.sh
│       ├── mixtral
│       │   ├── 8x22b
│       │   │   ├── 1_test_mixtral.sh
│       │   │   └── 2_test_mixtral.sh
│       │   ├── 8x7b
│       │   │   ├── 1_test_mixtral.sh
│       │   │   └── 2_test_mixtral.sh
│       │   └── Run_Mixtral.md
│       ├── test_checkpoint_resharding.sh
│       ├── test_convergence_1b_params.sh
│       ├── test_decode.sh
│       ├── test_decode_load_quantized_ckpt.sh
│       ├── test_decode_save_quantized_ckpt.sh
│       ├── test_determinism.sh
│       ├── test_dpo.sh
│       ├── test_gpt3.sh
│       ├── test_sft_trainer.sh
│       ├── test_tflops.sh
│       ├── test_tflops_16b_params.sh
│       ├── test_tflops_32b_params.sh
│       ├── test_tflops_64b_params.sh
│       └── test_vocab_creation.sh
├── getting_started
│   ├── Data_Input_Perf.md
│   ├── Data_Input_Pipeline.md
│   ├── First_run.md
│   ├── GCP_Workload_Observability.md
│   ├── Knowledge_Distillation.md
│   ├── Monitor_Goodput.md
│   ├── Run_Llama2.md
│   ├── Run_MaxText_via_multihost_job.md
│   ├── Run_MaxText_via_multihost_runner.md
│   ├── Run_MaxText_via_xpk.md
│   ├── Sharding.md
│   └── Use_Vertex_AI_Tensorboard.md
├── gpu_multi_process_run.sh
├── maxtext_custom_wheels.Dockerfile
├── maxtext_db_dependencies.Dockerfile
├── maxtext_dependencies.Dockerfile
├── maxtext_gpu_dependencies.Dockerfile
├── maxtext_jax_ai_image.Dockerfile
├── maxtext_libtpu_path.Dockerfile
├── maxtext_runner.Dockerfile
├── multihost_job.py
├── multihost_runner.py
├── pedagogical_examples
│   ├── __init__.py
│   ├── non_spmd.py
│   ├── shardings.py
│   └── shmap_collective_matmul.py
├── preflight.sh
├── pylintrc
├── pyproject.toml
├── pytest.ini
├── requirements.txt
├── requirements_with_jax_ai_image.txt
├── rto_setup.sh
├── setup.py
├── setup.sh
├── setup_gcsfuse.sh
├── setup_with_retries.sh
└── unit_test_and_lint.sh

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git
2 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
2 | * @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei
3 |
4 | # Features
5 | MaxText/experimental/rl @A9isha @khatwanimohit @gagika @richjames0
6 | MaxText/input_pipeline @aireenmei @SurbhiJainUSC @richjames0
7 | MaxText/kernels/megablox @RissyRan @michelle-yooh @gagika @richjames0
8 | MaxText/kernels/ragged_attention.py @patemotter @vipannalla @richjames0
9 | MaxText/layers/pipeline.py @gobbleturk @richjames0
10 | MaxText/layers/moe.py @RissyRan @michelle-yooh @gagika @richjames0
11 | MaxText/layers/multi_token_prediction.py @parambole @RissyRan @gagika @richjames0
12 |
13 | # Inference
14 | MaxText/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
15 | MaxText/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
16 | MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
17 |
18 | # Dockerfiles and dependencies
19 | *.Dockerfile @bvandermoon @yangyuwei @parambole @richjames0
20 | *.txt @bvandermoon @yangyuwei @parambole @richjames0
21 |
22 | # Workflow files
23 | .github/workflows @gobbleturk @khatwanimohit @shralex @parambole @richjames0
24 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Start with a short description of what the PR does and how this is a change from
4 | the past.
5 |
6 | The rest of the description includes relevant details and context, examples:
7 | - why is this change being made,
8 | - the problem being solved and any relevant context,
9 | - why this is a good solution,
10 | - some information about the specific implementation,
11 | - shortcomings of the solution and possible future improvements.
12 |
13 | If the change fixes a bug or a Github issue, please include a link, e.g.:
14 | FIXES: b/123456
15 | FIXES: #123456
16 |
17 | *Notice 1:* Once all tests pass, the "pull ready" label will automatically be assigned.
18 | This label is used for administrative purposes. Please do not add it manually.
19 |
20 | *Notice 2:* For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
21 |
22 | # Tests
23 |
24 | Please describe how you tested this change, and include any instructions and/or
25 | commands to reproduce.
26 |
27 | # Checklist
28 |
29 | Before submitting this PR, please make sure (put X in square brackets):
30 | - [ ] I have performed a self-review of my code.
31 | - [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
32 | - [ ] I have run end-to-end tests and provided workload links above if applicable.
33 | - [ ] I have made or will make corresponding changes to the doc if needed.
34 |
--------------------------------------------------------------------------------
/.github/workflows/require-checklist.yml:
--------------------------------------------------------------------------------
1 | name: Require Checklist
2 | on:
3 |   pull_request:
4 |     types: [opened, edited, synchronize]
5 | jobs:
6 |   check_pr_body:
7 |     runs-on: ubuntu-latest
8 |     steps:
9 |       - uses: mheap/require-checklist-action@v2
10 |         with:
11 |           requireChecklist: true # If this is true and there are no checklists detected, the action will fail
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/codespell-project/codespell
3 |     rev: v2.2.4
4 |     hooks:
5 |       - id: codespell
6 |         name: Running codespell for typos
7 |         entry: codespell -w --skip="*.txt,pylintrc,.*,assets/*" .
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |   "python.testing.pytestArgs": [],
3 |   "python.testing.cwd": "${workspaceFolder}/MaxText",
4 |   "python.testing.unittestEnabled": false,
5 |   "python.testing.pytestEnabled": true
6 | }
7 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | Google LLC
2 |
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
--------------------------------------------------------------------------------
/MaxText/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2023 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 |     https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | __author__ = "Google LLC"
18 | __version__ = "2025.04.25"
19 | __description__ = (
20 |     "MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and "
21 |     "targeting Google Cloud TPUs and GPUs for training and inference."
22 | )
23 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/16vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 16vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | #   --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | #   --device-type ${DEVICE_TYPE} --num-slices 16 --priority=high \
6 | #   --command "bash MaxText/configs/a3/llama_2_7b/16vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"
13 | export EXECUTABLE="train"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 |     export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true
23 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
24 | --xla_gpu_enable_highest_priority_async_stream=true
25 | --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
26 | --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
27 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
28 | --xla_gpu_enable_while_loop_double_buffering=true
29 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
30 | --xla_disable_hlo_passes=rematerialization"
31 |
32 | # 16 nodes
33 | python3 -m MaxText.$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
34 |   dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
35 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/1vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 1vm.sh"
2 |
3 | # Example command to invoke this script via XPK
4 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
5 | #   --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
6 | #   --device-type ${DEVICE_TYPE} --num-slices 1 \
7 | #   --command "bash MaxText/configs/a3/llama_2_7b/1vm.sh"
8 |
9 | # Stop execution if any command exits with error
10 | set -e
11 |
12 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
13 | export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 |     export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
23 | --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true
24 | --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
25 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
26 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
27 | --xla_gpu_enable_while_loop_double_buffering=true
28 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
29 | --xla_disable_hlo_passes=rematerialization"
30 |
31 |
32 | # 1 node, DATA_DP=1, ICI_FSDP=8
33 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
34 |   dcn_data_parallelism=1 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
35 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/2vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 2vm.sh"
2 |
3 | # Example command to invoke this script via XPK
4 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
5 | #   --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
6 | #   --device-type ${DEVICE_TYPE} --num-slices 2 \
7 | #   --command "bash MaxText/configs/a3/llama_2_7b/2vm.sh"
8 |
9 | # Stop execution if any command exits with error
10 | set -e
11 |
12 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
13 | export RUN_NAME="llama-2-2vm-$(date +%Y-%m-%d-%H-%M)"
14 |
15 | # Set environment variables
16 | for ARGUMENT in "$@"; do
17 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
18 |     export "$KEY"="$VALUE"
19 | done
20 |
21 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
22 | --xla_gpu_enable_latency_hiding_scheduler=true
23 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
24 | --xla_gpu_enable_highest_priority_async_stream=true
25 | --xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728
26 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
27 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
28 | --xla_gpu_enable_while_loop_double_buffering=true
29 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
30 | --xla_disable_hlo_passes=rematerialization"
31 |
32 |
33 | # 2 nodes
34 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
35 |   dcn_data_parallelism=2 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
36 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/4vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 4vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | #   --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | #   --device-type ${DEVICE_TYPE} --num-slices 4 \
6 | #   --command "bash MaxText/configs/a3/llama_2_7b/4vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-4vm-$(date +%Y-%m-%d-%H-%M)"
13 |
14 | # Set environment variables
15 | for ARGUMENT in "$@"; do
16 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
17 |     export "$KEY"="$VALUE"
18 | done
19 |
20 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
21 | --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
22 | --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true
23 | --xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728
24 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
25 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
26 | --xla_gpu_enable_while_loop_double_buffering=true
27 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
28 | --xla_disable_hlo_passes=rematerialization"
29 |
30 | # 4 nodes
31 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
32 |   dcn_data_parallelism=4 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
33 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/8vm.sh:
--------------------------------------------------------------------------------
1 | echo "Running 8vm.sh"
2 | # Example command to invoke this script via XPK
3 | # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
4 | #   --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
5 | #   --device-type ${DEVICE_TYPE} --num-slices 8 \
6 | #   --command "bash MaxText/configs/a3/llama_2_7b/8vm.sh"
7 |
8 | # Stop execution if any command exits with error
9 | set -e
10 |
11 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
12 | export RUN_NAME="llama-2-8vm-$(date +%Y-%m-%d-%H-%M)"
13 |
14 | # Set environment variables
15 | for ARGUMENT in "$@"; do
16 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
17 |     export "$KEY"="$VALUE"
18 | done
19 |
20 | export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
21 | --xla_gpu_enable_latency_hiding_scheduler=true
22 | --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
23 | --xla_gpu_enable_highest_priority_async_stream=true
24 | --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
25 | --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
26 | --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
27 | --xla_gpu_enable_while_loop_double_buffering=true
28 | --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
29 | --xla_disable_hlo_passes=rematerialization"
30 |
31 | # 8 nodes
32 | python3 -m MaxText.train MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
33 |   dcn_data_parallelism=8 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
34 |
--------------------------------------------------------------------------------
/MaxText/configs/a3/llama_2_7b/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # High Performance Model Configs on A3 GPU
18 | Expected performance results for Llama2-7B model running on A3 GPU:
19 |
20 |
21 | ### Llama2-7B
22 | | Hardware             | TFLOP/sec/chip |
23 | | -------------------- | -------------- |
24 | | 1x A3 (h100-80gb-8)  | 492            |
25 | | 2x A3 (h100-80gb-8)  | 422            |
26 | | 4x A3 (h100-80gb-8)  | 407            |
27 | | 8x A3 (h100-80gb-8)  | 409            |
28 | | 16x A3 (h100-80gb-8) | 375            |
29 |
--------------------------------------------------------------------------------
/MaxText/configs/dpo.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | use_dpo: true
4 | train_data_columns: ['chosen', 'rejected']
5 | eval_data_columns: ['chosen', 'rejected']
6 | base_output_directory: 'gs://maxtext-external/logs'
7 |
8 | per_device_batch_size: 2.0
9 | steps: 10
10 | max_target_length: 512
11 | eval_interval: 5 # test eval once, in the middle of 10 training steps
12 | eval_steps: 2
13 |
14 | # TFDS Pipeline ----------------------
15 | dataset_type: tfds
16 | dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
17 | dataset_name: 'tfds:1.0.0'
18 | eval_dataset_name: 'tfds:1.0.0'
19 | eval_split: 'test'
20 |
21 | # HF Pipeline -------------------------
22 | hf_eval_split: 'test'
23 |
24 | gradient_clipping_threshold: 10.0
25 | learning_rate: 5.0e-7
26 | dpo_label_smoothing: 0.0
27 | dpo_beta: 0.1
28 |
29 | enable_goodput_recording: false
30 | monitor_goodput: false
31 | enable_checkpointing: true
32 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/1024b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 1024b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/1024b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 |   remat_policy=full global_parameter_scale=1024\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/128b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 128b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/128b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=30 per_device_batch_size=2 enable_checkpointing=false\
24 |   remat_policy=minimal global_parameter_scale=128\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""\
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/256b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 256b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/256b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 |   remat_policy=minimal global_parameter_scale=256\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/32b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 32b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/32b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=30 per_device_batch_size=8 enable_checkpointing=false\
24 |   remat_policy=minimal global_parameter_scale=32\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""\
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/512b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 512b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/512b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=20 per_device_batch_size=2 enable_checkpointing=false\
24 |   remat_policy=full global_parameter_scale=512\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
29 |
--------------------------------------------------------------------------------
/MaxText/configs/experimental/64b.sh:
--------------------------------------------------------------------------------
1 | echo "Running 64b.sh"
2 | # Example command to invoke this script
3 | # bash MaxText/configs/experimental/64b.sh
4 |
5 | # Stop execution if any command exits with error
6 | set -e
7 |
8 | export OUTPUT_PATH="gs://maxtext-experiments-multipod"
9 | export DATASET_PATH="gs://maxtext-dataset/"
10 |
11 | # Set environment variables
12 | for ARGUMENT in "$@"; do
13 |     IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
14 |     export "$KEY"="$VALUE"
15 | done
16 |
17 | # Use preflight.sh to set up env based on platform
18 | bash preflight.sh PLATFORM=$PLATFORM
19 |
20 | # Train
21 | export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
22 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\
23 |   steps=30 per_device_batch_size=4 enable_checkpointing=false\
24 |   remat_policy=minimal global_parameter_scale=64\
25 |   ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\
26 |   max_target_length=2048 base_output_directory=$OUTPUT_PATH\
27 |   dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
28 |   dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""\
29 |
--------------------------------------------------------------------------------
/MaxText/configs/gpu_smoke_test.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | hardware: "gpu"
4 | attention: "dot_product"
5 | base_emb_dim: 8
6 | base_num_query_heads: 4
7 | base_num_kv_heads: 4
8 | base_mlp_dim: 32
9 | base_num_decoder_layers: 8
10 | head_dim: 16
11 | per_device_batch_size: 2
12 | max_target_length: 1024
13 | dataset_type: "synthetic"
14 | steps: 10
15 |
--------------------------------------------------------------------------------
/MaxText/configs/inference_jetstream.yml:
--------------------------------------------------------------------------------
1 | base_config: "base.yml"
2 |
3 | enable_jax_profiler: False
4 | jax_profiler_port: 9999
5 |
6 | enable_model_warmup: False
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek2-16b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V2-Lite - 16B
16 |
17 | base_emb_dim: 2048
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 10944
21 | base_moe_mlp_dim: 1408
22 | base_num_decoder_layers: 27
23 | first_num_dense_layers: 1
24 | mlp_activations: ["silu","linear"]
25 | vocab_size: 102400
26 | enable_dropout: False
27 | logits_via_embedding: False
28 | normalization_layer_epsilon: 1.0e-6
29 | num_experts: 64
30 | num_experts_per_tok: 6
31 | shared_experts: 2
32 | routed_scaling_factor: 1.0
33 | routed_score_func: "softmax"
34 | routed_bias: False
35 | decoder_block: "deepseek"
36 | # MLA
37 | attention_type: "mla"
38 | q_lora_rank: 0
39 | kv_lora_rank: 512
40 | qk_nope_head_dim: 128
41 | qk_rope_head_dim: 64
42 | v_head_dim: 128
43 | # RoPE
44 | rope_type: "yarn"
45 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
46 | max_position_embeddings: 163840
47 | original_max_position_embeddings: 4096
48 | rope_factor: 40
49 | beta_fast: 32
50 | mscale: 0.707
51 |
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek2-236b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V2 - 236B
16 | # Please note: DeepSeek V2 - 236B is not fully supported at this moment
17 |
18 | base_emb_dim: 5120
19 | base_num_query_heads: 128
20 | base_num_kv_heads: 128
21 | base_mlp_dim: 12288
22 | base_moe_mlp_dim: 1536
23 | base_num_decoder_layers: 60
24 | first_num_dense_layers: 1
25 | mlp_activations: ["silu","linear"]
26 | vocab_size: 102400
27 | enable_dropout: False
28 | logits_via_embedding: False
29 | normalization_layer_epsilon: 1.0e-6
30 | num_experts: 160
31 | num_experts_per_tok: 6
32 | shared_experts: 2
33 | routed_scaling_factor: 16.0
34 | routed_score_func: "softmax"
35 | routed_bias: False
36 | decoder_block: "deepseek"
37 | # MLA
38 | attention_type: "mla"
39 | q_lora_rank: 1536
40 | kv_lora_rank: 512
41 | qk_nope_head_dim: 128
42 | qk_rope_head_dim: 64
43 | v_head_dim: 128
44 | # RoPE
45 | rope_type: "yarn"
46 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
47 | max_position_embeddings: 163840
48 | original_max_position_embeddings: 4096
49 | rope_factor: 40
50 | beta_fast: 32
51 | mscale: 0.707
52 |
--------------------------------------------------------------------------------
/MaxText/configs/models/deepseek3-671b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for DeepSeek V3 - 671B
16 |
17 | # For DeepSeek default device-limited routing,
18 | # please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
19 |
20 | base_emb_dim: 7168
21 | base_num_query_heads: 128
22 | base_num_kv_heads: 128
23 | base_mlp_dim: 18432
24 | base_moe_mlp_dim: 2048
25 | base_num_decoder_layers: 61
26 | first_num_dense_layers: 3
27 | mlp_activations: ["silu","linear"]
28 | vocab_size: 129280
29 | enable_dropout: False
30 | logits_via_embedding: False
31 | normalization_layer_epsilon: 1.0e-6
32 | num_experts: 256
33 | num_experts_per_tok: 8
34 | shared_experts: 1
35 | routed_scaling_factor: 2.5
36 | routed_score_func: "sigmoid"
37 | routed_bias: True
38 | decoder_block: "deepseek"
39 | # MLA
40 | attention_type: "mla"
41 | q_lora_rank: 1536
42 | kv_lora_rank: 512
43 | qk_nope_head_dim: 128
44 | qk_rope_head_dim: 64
45 | v_head_dim: 128
46 | mscale: 1.0
47 | # RoPE
48 | rope_type: "yarn"
49 | rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
50 | max_position_embeddings: 163840
51 | original_max_position_embeddings: 4096
52 | rope_factor: 40
53 | beta_fast: 32
54 |
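The device-limited routing note in deepseek3-671b.yml above refers to plain config overrides: MaxText selects a model file from MaxText/configs/models via model_name and accepts overrides as KEY=VALUE arguments to the trainer. A minimal sketch (the run name and output bucket are hypothetical):

    python3 -m MaxText.train MaxText/configs/base.yml model_name=deepseek3-671b \
        n_routing_groups=8 topk_routing_group=4 \
        run_name=deepseek3-routing-test base_output_directory=gs://my-bucket/logs dataset_type=synthetic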
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma-2b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma-2b
16 |
17 | base_emb_dim: 2048
18 | base_num_query_heads: 8
19 | base_num_kv_heads: 1
20 | base_mlp_dim: 16384
21 | base_num_decoder_layers: 18
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma-7b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma-7b
16 |
17 | base_emb_dim: 3072
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 24576
21 | base_num_decoder_layers: 28
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-27b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-27B
16 |
17 | base_emb_dim: 4608
18 | base_num_query_heads: 32
19 | base_num_kv_heads: 16
20 | base_mlp_dim: 36864
21 | base_num_decoder_layers: 23 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 128
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-2b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-2B
16 |
17 | base_emb_dim: 2304
18 | base_num_query_heads: 8
19 | base_num_kv_heads: 4
20 | base_mlp_dim: 9216
21 | base_num_decoder_layers: 13 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
--------------------------------------------------------------------------------
/MaxText/configs/models/gemma2-9b.yml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # model config for gemma2-9B
16 |
17 | base_emb_dim: 3584
18 | base_num_query_heads: 16
19 | base_num_kv_heads: 8
20 | base_mlp_dim: 14336
21 | base_num_decoder_layers: 21 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
22 | head_dim: 256
23 | mlp_activations: ["gelu","linear"]
24 | vocab_size: 256128
25 | decoder_block: "gemma2"
26 | normalization_layer_epsilon: 1.e-06
27 | logits_via_embedding: True
28 | final_logits_soft_cap: 30.0
29 | attn_logits_soft_cap: 50.0
30 | sliding_window_size: 4096
31 | use_post_attn_norm: True
32 | use_post_ffw_norm: True
33 |
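A quick consistency check on the "half of the real number of layers" comments in the gemma2 configs above: Gemma 2 27B has 46 transformer layers, and merging each [local_attention, global_attention] pair into one MaxText layer gives 46 / 2 = 23; likewise 26 / 2 = 13 for the 2B and 42 / 2 = 21 for the 9B, matching base_num_decoder_layers in each file.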
14 | 15 | # model config for gemma3-12b 16 | 17 | base_num_decoder_layers: 48 18 | base_emb_dim: 3840 19 | base_num_query_heads: 16 20 | base_num_kv_heads: 8 21 | base_mlp_dim: 15360 22 | head_dim: 256 23 | mlp_activations: ["gelu","linear"] 24 | vocab_size: 262_144 25 | decoder_block: "gemma3" 26 | normalization_layer_epsilon: 1e-6 27 | logits_via_embedding: True 28 | sliding_window_size: 1024 29 | use_post_attn_norm: true 30 | use_post_ffw_norm: true 31 | local_rope_max_timescale: 10_000 32 | rope_max_timescale: 1_000_000 33 | -------------------------------------------------------------------------------- /MaxText/configs/models/gemma3-27b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for gemma3-27b 16 | 17 | base_num_decoder_layers: 62 18 | base_emb_dim: 5376 19 | base_num_query_heads: 32 20 | base_num_kv_heads: 16 21 | base_mlp_dim: 21504 22 | head_dim: 128 23 | mlp_activations: ["gelu","linear"] 24 | vocab_size: 262_144 25 | decoder_block: "gemma3" 26 | normalization_layer_epsilon: 1e-6 27 | logits_via_embedding: True 28 | sliding_window_size: 1024 29 | use_post_attn_norm: true 30 | use_post_ffw_norm: true 31 | local_rope_max_timescale: 10_000 32 | rope_max_timescale: 1_000_000 33 | -------------------------------------------------------------------------------- /MaxText/configs/models/gemma3-4b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # model config for gemma3-4b 16 | 17 | base_num_decoder_layers: 34 18 | base_emb_dim: 2560 19 | base_num_query_heads: 8 20 | base_num_kv_heads: 4 21 | base_mlp_dim: 10240 22 | head_dim: 256 23 | mlp_activations: ["gelu","linear"] 24 | vocab_size: 262_144 25 | decoder_block: "gemma3" 26 | normalization_layer_epsilon: 1e-6 27 | logits_via_embedding: True 28 | sliding_window_size: 1024 29 | use_post_attn_norm: true 30 | use_post_ffw_norm: true 31 | local_rope_max_timescale: 10_000 32 | rope_max_timescale: 1_000_000 33 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpt3-175b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing software 10 | # distributed under the License is distributed on an "AS IS" BASIS 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for gpt3-175b 16 | 17 | base_emb_dim: 12288 18 | base_num_query_heads: 96 19 | base_num_kv_heads: 96 20 | base_mlp_dim: 49152 21 | base_num_decoder_layers: 96 22 | head_dim: 128 23 | trainable_position_size: 16384 24 | mlp_activations: ["gelu"] 25 | vocab_size: 50304 26 | enable_dropout: False 27 | logits_via_embedding: True 28 | normalize_embedding_logits: False 29 | logits_dot_in_fp32: False 30 | normalization_layer_epsilon: 1.e-05 31 | use_iota_embed: True 32 | fused_qkv: True 33 | opt_type: "adam_pax" 34 | decoder_block: "gpt3" 35 | dataset_path: "gs://mlperf-llm-public2" 36 | dataset_name: "c4/en:3.0.4" 37 | eval_dataset_name: "c4/en:3.0.5" 38 | gradient_clipping_threshold: 1. 39 | adam_b1: 0.9 40 | adam_b2: 0.95 41 | adam_eps: 1.e-8 42 | adam_weight_decay: 0.1 43 | checkpoint_period: 10_000 44 | target_eval_loss: 2.69 45 | eval_per_device_batch_size: 1.0 46 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpt3-22b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing software 10 | # distributed under the License is distributed on an "AS IS" BASIS 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # model config for gpt3-22b 16 | 17 | base_emb_dim: 6144 18 | base_num_query_heads: 24 19 | base_num_kv_heads: 24 20 | base_mlp_dim: 24576 21 | base_num_decoder_layers: 48 22 | head_dim: 256 23 | max_target_length: 1024 24 | trainable_position_size: 16384 25 | mlp_activations: ["gelu"] 26 | vocab_size: 32768 27 | enable_dropout: False 28 | logits_via_embedding: True 29 | normalize_embedding_logits: False 30 | logits_dot_in_fp32: False 31 | normalization_layer_epsilon: 1.e-05 32 | use_iota_embed: True 33 | fused_qkv: True 34 | opt_type: "adam_pax" 35 | decoder_block: "gpt3" 36 | gradient_clipping_threshold: 1. 37 | adam_b1: 0.9 38 | adam_b2: 0.95 39 | adam_eps: 1.e-8 40 | adam_weight_decay: 0.1 41 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpt3-52k.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing software 10 | # distributed under the License is distributed on an "AS IS" BASIS 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for gpt3-52k, i.e. a fake and small model for testing purposes only 16 | 17 | base_emb_dim: 16 18 | base_num_query_heads: 2 19 | base_num_kv_heads: 2 20 | base_mlp_dim: 64 21 | base_num_decoder_layers: 1 22 | head_dim: 8 23 | trainable_position_size: 2048 24 | mlp_activations: ["gelu"] 25 | vocab_size: 1024 26 | enable_dropout: False 27 | logits_via_embedding: True 28 | normalize_embedding_logits: False 29 | logits_dot_in_fp32: False 30 | normalization_layer_epsilon: 1.e-05 31 | use_iota_embed: True 32 | fused_qkv: True 33 | opt_type: "adam_pax" 34 | decoder_block: "gpt3" 35 | gradient_clipping_threshold: 1. 36 | adam_b1: 0.9 37 | adam_b2: 0.95 38 | adam_eps: 1.e-8 39 | adam_weight_decay: 0.1 40 | attention: "dot_product" # head_dim 8 is too small for splash/flash attention 41 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpt3-6b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing software 10 | # distributed under the License is distributed on an "AS IS" BASIS 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | # model config for gpt3-6b 16 | 17 | base_emb_dim: 3072 18 | base_num_query_heads: 12 19 | base_num_kv_heads: 12 20 | base_mlp_dim: 12288 21 | base_num_decoder_layers: 48 22 | head_dim: 256 23 | max_target_length: 1024 24 | trainable_position_size: 16384 25 | mlp_activations: ["gelu"] 26 | vocab_size: 32768 27 | enable_dropout: False 28 | logits_via_embedding: True 29 | normalize_embedding_logits: False 30 | logits_dot_in_fp32: False 31 | normalization_layer_epsilon: 1.e-05 32 | use_iota_embed: True 33 | fused_qkv: True 34 | opt_type: "adam_pax" 35 | decoder_block: "gpt3" 36 | gradient_clipping_threshold: 1. 37 | adam_b1: 0.9 38 | adam_b2: 0.95 39 | adam_eps: 1.e-8 40 | adam_weight_decay: 0.1 41 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/llama2_70b.yml: -------------------------------------------------------------------------------- 1 | base_config: "base.yml" 2 | 3 | run_name: "gpu_train_test" 4 | # Args coming from the NVIDIA spreadsheet http://shortn/_AhULYn1mX4. 5 | hardware: "gpu" 6 | steps: 30 7 | model_name: "llama2-70b" 8 | enable_checkpointing: False 9 | attention: "cudnn_flash_te" 10 | remat_policy: "full" 11 | use_iota_embed: True 12 | scan_layers: True 13 | dataset_type: "synthetic" 14 | async_checkpointing: False 15 | logits_dot_in_fp32: False 16 | 17 | per_device_batch_size: 6 18 | max_target_length: 4096 19 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/llama2_7b.yml: -------------------------------------------------------------------------------- 1 | base_config: "base.yml" 2 | 3 | run_name: "gpu_train_test" 4 | hardware: "gpu" 5 | steps: 30 6 | per_device_batch_size: 4 7 | max_target_length: 4096 8 | model_name: "llama2-7b" 9 | enable_checkpointing: False 10 | attention: "cudnn_flash_te" 11 | remat_policy: "minimal_flash" 12 | use_iota_embed: True 13 | scan_layers: False 14 | dataset_type: "synthetic" 15 | async_checkpointing: False -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/llama3.1_405b.yml: -------------------------------------------------------------------------------- 1 | base_config: "base.yml" 2 | run_name: "gpu_train_test" 3 | # Args coming from the NVIDIA spreadsheet http://shortn/_AhULYn1mX4. 4 | hardware: "gpu" 5 | steps: 10 6 | model_name: "llama3.1-405b" 7 | enable_checkpointing: False 8 | #attention: "cudnn_flash_te" 9 | remat_policy: "full" 10 | use_iota_embed: True 11 | scan_layers: True 12 | dataset_type: "synthetic" 13 | async_checkpointing: False 14 | logits_dot_in_fp32: False 15 | per_device_batch_size: 1.0 16 | max_target_length: 4096 17 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/llama3_70b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | base_config: "base.yml" 16 | # The model_name guarantees we will use the correct model from 17 | # configs/models/llama3-70b.yml, see update_model_vars in pyconfig.py for details. 18 | model_name: "llama3-70b" 19 | 20 | run_name: "gpu_train_test" 21 | hardware: "gpu" 22 | steps: 30 23 | per_device_batch_size: 4 24 | max_target_length: 8192 25 | attention: "cudnn_flash_te" 26 | remat_policy: "full" 27 | use_iota_embed: True 28 | dataset_type: "synthetic" 29 | reuse_example_batch: 1 30 | enable_checkpointing: False 31 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/llama3_8b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | base_config: "base.yml" 16 | # The model_name guarantees we will use the correct model from 17 | # configs/models/llama3-8b.yml, see update_model_vars in pyconfig.py for details. 18 | model_name: "llama3-8b" 19 | 20 | run_name: "gpu_train_test" 21 | hardware: "gpu" 22 | steps: 30 23 | per_device_batch_size: 12 24 | max_target_length: 8192 25 | attention: "cudnn_flash_te" 26 | remat_policy: "minimal_flash" 27 | use_iota_embed: True 28 | dataset_type: "synthetic" 29 | reuse_example_batch: 1 30 | enable_checkpointing: False 31 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/mixtral_8x1b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | base_config: "base.yml" 16 | # The model_name guarantees we use the correct model params from 17 | # configs/models/mixtral-8x7b. mixtral-8x1b is mixtral-8x7b but 18 | # with different base_num_decoder_layers.
19 | model_name: "mixtral-8x7b" 20 | base_num_decoder_layers: 5 21 | 22 | run_name: "gpu_train_test" 23 | hardware: "gpu" 24 | steps: 30 25 | 26 | per_device_batch_size: 8 27 | max_target_length: 4096 28 | attention: "cudnn_flash_te" 29 | remat_policy: "full" 30 | use_iota_embed: True 31 | dataset_type: "synthetic" 32 | reuse_example_batch: 1 33 | enable_checkpointing: False 34 | megablox: False 35 | sparse_matmul: False 36 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/mixtral_8x2b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | base_config: "base.yml" 16 | # The model_name gurantees we use the correct model params from 17 | # configs/models/mixtral-8x7b. mixtral-8x2b is mixtral-8x7b but 18 | # with different base_num_decoder_layers. 19 | model_name: "mixtral-8x7b" 20 | base_num_decoder_layers: 10 21 | 22 | run_name: "gpu_train_test" 23 | hardware: "gpu" 24 | steps: 30 25 | 26 | per_device_batch_size: 8 27 | max_target_length: 4096 28 | attention: "cudnn_flash_te" 29 | remat_policy: "full" 30 | use_iota_embed: True 31 | dataset_type: "synthetic" 32 | reuse_example_batch: 1 33 | enable_checkpointing: False 34 | megablox: False 35 | sparse_matmul: False 36 | -------------------------------------------------------------------------------- /MaxText/configs/models/gpu/mixtral_8x7b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | base_config: "base.yml" 16 | # The model_name gurantees we will use the correct model from 17 | # configs/models/mixtral-8x7b.yml, see update_model_vars in pyconfig.py for details. 
18 | model_name: "mixtral-8x7b" 19 | 20 | run_name: "gpu_train_test" 21 | hardware: "gpu" 22 | steps: 30 23 | per_device_batch_size: 12 24 | max_target_length: 4096 25 | attention: "cudnn_flash_te" 26 | remat_policy: "minimal_flash" 27 | use_iota_embed: True 28 | dataset_type: "synthetic" 29 | reuse_example_batch: 1 30 | enable_checkpointing: False 31 | megablox: False 32 | scan_layers: False 33 | tokenizer_path: "/deps/assets/tokenizer.mistral-v1" 34 | profiler: "nsys" 35 | capacity_factor: 1.0 36 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama2-13b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama2-13b 16 | 17 | base_emb_dim: 5120 18 | base_num_query_heads: 40 19 | base_num_kv_heads: 40 20 | base_mlp_dim: 13824 21 | base_num_decoder_layers: 40 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 32000 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | decoder_block: "llama2" 29 | logical_axis_rules: [['norm', 'fsdp']] 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama2-70b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama2-7b 16 | 17 | base_emb_dim: 8192 18 | base_num_query_heads: 64 19 | base_num_kv_heads: 8 20 | base_mlp_dim: 28672 21 | base_num_decoder_layers: 80 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 32000 25 | logits_via_embedding: False 26 | normalization_layer_epsilon: 1.0e-5 27 | decoder_block: "llama2" 28 | logical_axis_rules: [['norm', 'fsdp']] 29 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama2-7b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama2-7b 16 | 17 | base_emb_dim: 4096 18 | base_num_query_heads: 32 19 | base_num_kv_heads: 32 20 | base_mlp_dim: 11008 21 | base_num_decoder_layers: 32 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 32000 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | decoder_block: "llama2" 29 | logical_axis_rules: [['norm', 'fsdp']] 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3-405b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama3-405b 16 | 17 | 18 | base_emb_dim: 16384 19 | base_num_query_heads: 128 20 | base_num_kv_heads: 8 21 | base_num_decoder_layers: 126 22 | base_mlp_dim: 53248 23 | head_dim: 128 24 | mlp_activations: ["silu","linear"] 25 | vocab_size: 128256 26 | enable_dropout: False 27 | logits_via_embedding: False 28 | normalization_layer_epsilon: 1.0e-5 29 | rope_max_timescale: 500_000 30 | decoder_block: "llama2" # Uses the same decoder block as llama2 31 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3-70b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # model config for llama3-70b 16 | 17 | base_emb_dim: 8192 18 | base_num_query_heads: 64 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 80 21 | base_mlp_dim: 28672 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3-8b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama3-8b 16 | 17 | base_emb_dim: 4096 18 | base_num_query_heads: 32 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 32 21 | base_mlp_dim: 14336 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3.1-405b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama3.1-405b 16 | 17 | base_emb_dim: 16384 18 | base_num_query_heads: 128 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 126 21 | base_mlp_dim: 53248 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3.1-70b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama3.1-70b 16 | 17 | base_emb_dim: 8192 18 | base_num_query_heads: 64 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 80 21 | base_mlp_dim: 28672 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3.1-8b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for llama3.1-8b 16 | 17 | base_emb_dim: 4096 18 | base_num_query_heads: 32 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 32 21 | base_mlp_dim: 14336 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama3.3-70b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # model config for llama3.3-70b 16 | 17 | base_emb_dim: 8192 18 | base_num_query_heads: 64 19 | base_num_kv_heads: 8 20 | base_num_decoder_layers: 80 21 | base_mlp_dim: 28672 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 128256 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 500_000 29 | decoder_block: "llama2" # Uses the same decoder block as llama2 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama4-17b-128e.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | decoder_block: "llama4" 17 | mlp_activations: ["silu","linear"] 18 | enable_dropout: False 19 | tokenizer_type: "huggingface" 20 | 21 | base_emb_dim: 5120 22 | base_num_decoder_layers: 48 23 | base_num_query_heads: 40 24 | base_num_kv_heads: 8 25 | base_mlp_dim: 16384 26 | base_moe_mlp_dim: 8192 27 | vocab_size: 202048 28 | normalization_layer_epsilon: 1e-05 29 | rope_max_timescale: 500000 30 | rope_type: "llama3.1" 31 | rope_use_scale: False 32 | num_experts: 128 33 | shared_experts: 1 34 | num_experts_per_tok: 1 35 | use_qk_norm: False 36 | nope_layer_interval: 4 # Every fourth layer should NOT use RoPE 37 | interleave_moe_layer_step: 2 # Every 2nd layer is MoE layer, and 1st layer is dense layer 38 | inhomogeneous_layer_cycle_interval: 4 # Every four layers the pattern of nope and moe repeats (least common multiple of nope interval and moe interval) 39 | 40 | temperature_tuning: True 41 | # Chunk attention is used on all RoPE layers 42 | # otherwise, on NoPE layers, use global attention 43 | chunk_attn_window_size: 8192 44 | image_size_for_vit: 336 45 | -------------------------------------------------------------------------------- /MaxText/configs/models/llama4-17b-16e.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | decoder_block: "llama4" 17 | mlp_activations: ["silu","linear"] 18 | enable_dropout: False 19 | logits_via_embedding: False 20 | tokenizer_type: "huggingface" 21 | 22 | base_emb_dim: 5120 23 | base_num_decoder_layers: 48 24 | base_num_query_heads: 40 25 | base_num_kv_heads: 8 26 | base_mlp_dim: 16384 27 | base_moe_mlp_dim: 8192 28 | vocab_size: 202048 29 | normalization_layer_epsilon: 1e-05 30 | rope_max_timescale: 500000 31 | rope_type: "llama3.1" 32 | num_experts: 16 33 | shared_experts: 1 34 | num_experts_per_tok: 1 35 | use_qk_norm: True # Llama4 models apply an L2Norm to the Query and Keys after RoPE 36 | nope_layer_interval: 4 # Every fourth layer should NOT use RoPE 37 | interleave_moe_layer_step: 1 # Every layer is MoE layer 38 | inhomogeneous_layer_cycle_interval: 4 # Every four layers the pattern of nope and moe repeats (least common multiple of nope interval and moe interval) 39 | 40 | temperature_tuning: True 41 | # Chunk attention is used on all RoPE layers 42 | # otherwise, on NoPE layers, use global attention 43 | chunk_attn_window_size: 8192 44 | image_size_for_vit: 336 45 | -------------------------------------------------------------------------------- /MaxText/configs/models/mistral-7b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for mistral-7b 16 | 17 | base_emb_dim: 4096 18 | base_num_query_heads: 32 19 | base_num_kv_heads: 8 20 | base_mlp_dim: 14336 21 | base_num_decoder_layers: 32 22 | head_dim: 128 23 | mlp_activations: ["silu","linear"] 24 | vocab_size: 32000 25 | enable_dropout: False 26 | logits_via_embedding: False 27 | normalization_layer_epsilon: 1.0e-5 28 | rope_max_timescale: 1_000_000 29 | decoder_block: "mistral" 30 | -------------------------------------------------------------------------------- /MaxText/configs/models/mixtral-8x22b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # model config for mixtral-8x22b 16 | # tokenizer_path is assets/tokenizer.mistral-v3 17 | 18 | base_emb_dim: 6144 19 | base_num_query_heads: 48 20 | base_num_kv_heads: 8 21 | base_mlp_dim: 16384 22 | base_num_decoder_layers: 56 23 | head_dim: 128 24 | mlp_activations: ["silu","linear"] 25 | vocab_size: 32768 26 | enable_dropout: False 27 | logits_via_embedding: False 28 | normalization_layer_epsilon: 1.0e-5 29 | num_experts: 8 30 | num_experts_per_tok: 2 31 | rope_max_timescale: 1_000_000 32 | decoder_block: "mixtral" 33 | -------------------------------------------------------------------------------- /MaxText/configs/models/mixtral-8x7b.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # model config for mixtral-8x7b 16 | # tokenizer_path is assets/tokenizer.mistral-v1 17 | 18 | base_emb_dim: 4096 19 | base_num_query_heads: 32 20 | base_num_kv_heads: 8 21 | base_mlp_dim: 14336 22 | base_num_decoder_layers: 32 23 | head_dim: 128 24 | mlp_activations: ["silu","linear"] 25 | vocab_size: 32000 26 | enable_dropout: False 27 | logits_via_embedding: False 28 | normalization_layer_epsilon: 1.0e-5 29 | num_experts: 8 30 | num_experts_per_tok: 2 31 | rope_max_timescale: 1_000_000 32 | decoder_block: "mixtral" 33 | -------------------------------------------------------------------------------- /MaxText/configs/quantization/README.md: -------------------------------------------------------------------------------- 1 | # Mixed precision quantization configs (currently supported for inference only). 2 | 3 | This directory contains sample json files representing mixed precision quantization configs. 4 | 5 | A mixed precision config json contains the following: 6 | Keys represent a regex for the layer to which the config is applied. 7 | Values represent the quantization config for the corresponding layer. 8 | 9 | The quantization config for any layer is defined using the following variables. 10 | w_bits: Number of bits used for weights, default None (implying no quantization) 11 | a_bits: Number of bits used for activations, default None (implying no quantization) 12 | w_scale: Clipping scale for weights, default 1.0 13 | a_scale: Clipping scale for activations, default 1.0 14 | tile_size: Tile size for subchannel, default -1 (implying no subchannel) 15 | 16 | For example, the config below implies 4-bit weight_only quantization for layers wi_0 and wo. 17 | { 18 | ".*/wi_0": {"w_bits": 4}, 19 | ".*/wo": {"w_bits": 4} 20 | } 21 | 22 | A special key '__default__' can be used to override the default values. 23 | For example, the following config defines 8-bit weight only quantization for all layers. 24 | { 25 | "__default__": {"w_bits": 8} 26 | } 27 | 28 | # To configure mixed precision quantization, define the following. 29 | 1. A json file (e.g. example.json) in this directory with desired config 30 | 2.
Set the following parameters defined in base.yml 31 | quantization="intmp" 32 | quant_cfg_path="/example.json" 33 | -------------------------------------------------------------------------------- /MaxText/configs/quantization/dense_llm_subchannel.json: -------------------------------------------------------------------------------- 1 | { 2 | "__default__": {"w_bits": 8, "a_bits": 8}, 3 | ".*/query": {"w_bits": 4, "tile_size": 128}, 4 | ".*/key": {"w_bits": 4, "tile_size": 256}, 5 | ".*/value": {"w_bits": 4}, 6 | ".*/out": {"w_bits": 4}, 7 | ".*/wi_0": {"w_bits": 4}, 8 | ".*/wo": {"w_bits": 4} 9 | } 10 | -------------------------------------------------------------------------------- /MaxText/configs/quantization/dense_llm_weight_only_scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "__default__": {"w_bits": 8}, 3 | ".*/query": {"w_bits": 4, "w_scale": 0.8}, 4 | ".*/key": {"w_bits": 4, "w_scale": 0.9}, 5 | ".*/value": {"w_bits": 4}, 6 | ".*/out": {"w_bits": 4}, 7 | ".*/wi_0": {"w_bits": 4}, 8 | ".*/wo": {"w_bits": 4} 9 | } -------------------------------------------------------------------------------- /MaxText/configs/quantization/int4_weight_only.json: -------------------------------------------------------------------------------- 1 | { 2 | "__default__": {"w_bits": 4} 3 | } -------------------------------------------------------------------------------- /MaxText/configs/quantization/int8_weight_only.json: -------------------------------------------------------------------------------- 1 | { 2 | "__default__": {"w_bits": 8} 3 | } -------------------------------------------------------------------------------- /MaxText/configs/sft.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | base_config: "base.yml" 16 | 17 | use_sft: True 18 | # sft_train_on_completion_only=False trains on both prompt and completion tokens; trains only on completion tokens otherwise 19 | sft_train_on_completion_only: True 20 | packing: True 21 | learning_rate: 2.e-5 22 | 23 | # -------------- HF pipeline -------------- 24 | dataset_type: hf 25 | hf_path: 'HuggingFaceH4/ultrachat_200k' 26 | train_split: 'train_sft' 27 | hf_eval_split: 'test_sft' 28 | train_data_columns: ['messages'] 29 | eval_data_columns: ['messages'] 30 | -------------------------------------------------------------------------------- /MaxText/configs/tpu_smoke_test.yml: -------------------------------------------------------------------------------- 1 | base_config: "base.yml" 2 | 3 | hardware: "tpu" 4 | async_checkpointing: false 5 | attention: autoselected 6 | base_emb_dim: 32 7 | base_num_query_heads: 8 8 | base_num_kv_heads: 8 9 | base_mlp_dim: 32 10 | base_num_decoder_layers: 8 11 | dataset_type: "synthetic" 12 | head_dim: 128 13 | per_device_batch_size: 1 14 | steps: 10 15 | -------------------------------------------------------------------------------- /MaxText/configs/trillium/gemma2_9b.sh: -------------------------------------------------------------------------------- 1 | # Gemma2-9b model. 2 | # This config will work out of the box for any number of trillium-256 slices. 3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/trillium/gemma2_9b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | 13 | 14 | # Stop execution if any command exits with error 15 | set -e 16 | 17 | export EXECUTABLE="train" # or train_compile 18 | export RUN_PREFLIGHT="true" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # The setup accommodates two cases: 27 | # 1) Passing the 'RUN_NAME' variable at runtime 28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 29 | if [ -n "$RUN_NAME" ]; 30 | then 31 | export M_RUN_NAME=$RUN_NAME 32 | fi 33 | 34 | # Set up network optimizations 35 | if [ "$RUN_PREFLIGHT" = "true" ]; then 36 | bash preflight.sh 37 | fi 38 | 39 | # Train 40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=114688 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 41 | 42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=gemma2-9b\ 43 | steps=15 per_device_batch_size=3 enable_checkpointing=false\ 44 | remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1\ 45 | max_target_length=8192 base_output_directory=$OUTPUT_PATH\ 46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\ 47 | attention='flash' sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048 48 | -------------------------------------------------------------------------------- /MaxText/configs/trillium/llama2_7b_4096.sh: 
-------------------------------------------------------------------------------- 1 | # Llama2-7b model. 2 | # This config will work out of the box for any number of trillium-256 slices. 3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/trillium/llama2_7b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | 13 | 14 | # Stop execution if any command exits with error 15 | set -e 16 | 17 | export EXECUTABLE="train" # or train_compile 18 | export RUN_PREFLIGHT="true" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # The setup accommodates two cases: 27 | # 1) Passing the 'RUN_NAME' variable at runtime 28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 29 | if [ -n "$RUN_NAME" ]; 30 | then 31 | export M_RUN_NAME=$RUN_NAME 32 | fi 33 | 34 | # Set up network optimizations 35 | if [ "$RUN_PREFLIGHT" = "true" ]; then 36 | bash preflight.sh 37 | fi 38 | 39 | # Train 40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" 41 | 42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-7b\ 43 | steps=15 per_device_batch_size=12 enable_checkpointing=false\ 44 | remat_policy=full ici_fsdp_parallelism=-1\ 45 | max_target_length=4096 base_output_directory=$OUTPUT_PATH\ 46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\ 47 | attention='flash' sa_block_q=1024 sa_block_q_dkv=2048 sa_block_q_dq=2048 48 | -------------------------------------------------------------------------------- /MaxText/configs/trillium/mixtral_8x7b.sh: -------------------------------------------------------------------------------- 1 | # Mixtral-8x7b model. 2 | # This config will work out of the box for any number of trillium-256 slices. 
3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/trillium/mixtral_8x7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | 13 | 14 | # Stop execution if any command exits with error 15 | set -e 16 | 17 | export EXECUTABLE="train" # or train_compile 18 | export RUN_PREFLIGHT="true" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # The setup accommodates two cases: 27 | # 1) Passing the 'RUN_NAME' variable at runtime 28 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 29 | if [ -n "$RUN_NAME" ]; 30 | then 31 | export M_RUN_NAME=$RUN_NAME 32 | fi 33 | 34 | # Set up network optimizations 35 | if [ "$RUN_PREFLIGHT" = "true" ]; then 36 | bash preflight.sh 37 | fi 38 | 39 | # Train 40 | export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=81920 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" 41 | 42 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=mixtral-8x7b\ 43 | steps=15 per_device_batch_size=32 enable_checkpointing=false\ 44 | remat_policy=full ici_fsdp_parallelism=-1\ 45 | max_target_length=1024 base_output_directory=$OUTPUT_PATH\ 46 | reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true\ 47 | attention='flash' sa_block_q=1024 sa_block_q_dkv=2048 sa_block_q_dq=2048 48 | -------------------------------------------------------------------------------- /MaxText/configs/v4/README.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # High Performance Model Configs on TPU v4 18 | Expected performance results for 22B and 52B parameter models running on TPU v4: 19 | 20 | 21 | ### 22B model 22 | | Hardware | TFLOP/sec/chip | MFU | 23 | | ----------- | ---------------- | ----- | 24 | | 1x v4-128 | 156 | 56.7% | 25 | | 2x v4-128 | 152 | 55.2% | 26 | | 4x v4-128 | 149 | 54.3% | 27 | | 8x v4-128 | 146 | 53.2% | 28 | 29 | ### 52B model 30 | | Hardware | TFLOP/sec/chip | MFU | 31 | | ----------- | ---------------- | ----- | 32 | | 1x v4-384 | 154 | 56.0% | 33 | | 2x v4-384 | 162 | 58.9% | # this is counterintuitively higher than single slice because of choices made by the compiler, not for a fundamental reason.
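34 | 35 | MFU here is the measured model TFLOP/sec/chip divided by the chip's peak throughput; a minimal sanity check of the 22B single-slice row, assuming TPU v4's nominal peak of 275 bf16 TFLOP/sec/chip (an assumption, not stated in this file): 36 | 37 | ```python 38 | # Hypothetical MFU arithmetic (assumes v4 peak = 275 bf16 TFLOP/s/chip). 39 | V4_PEAK_TFLOPS_BF16 = 275 40 | print(f"MFU = {156 / V4_PEAK_TFLOPS_BF16:.1%}")  # ~56.7%, matching the 1x v4-128 row 41 | ```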
-------------------------------------------------------------------------------- /MaxText/configs/v5e/README.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # High Performance Model Training Configs on TPU v5e 18 | Expected performance results for 16B, 32B, 64B, and 128B parameter models running on TPU v5e: 19 | 20 | 21 | | Hardware | 16B TFLOP/sec/chip | 16B MFU | 32B TFLOP/sec/chip | 32B MFU | 64B TFLOP/sec/chip | 64B MFU | 128B TFLOP/sec/chip | 128B MFU | 22 | | ----------- | -----------------: | ------- | -----------------: | ------- | -----------------: | ------- | ------------------: | -------- | 23 | | 1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% | 24 | | 2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% | 25 | | 4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% | 26 | | 8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% | 27 | | 16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% | 28 | | 32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% | 29 | -------------------------------------------------------------------------------- /MaxText/configs/v5e/llama2_13b.sh: -------------------------------------------------------------------------------- 1 | # Llama2 13B model. 2 | # This config will work out of the box for any number of v5e-256 slices. 3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/v5e/llama2_13b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | # Example to AOT compile: 13 | # bash MaxText/configs/v5e/llama2_13b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 14 | 15 | 16 | # Stop execution if any command exits with error 17 | set -e 18 | 19 | export EXECUTABLE="train" # or train_compile 20 | export RUN_PREFLIGHT="true" 21 | 22 | # Set environment variables 23 | for ARGUMENT in "$@"; do 24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 25 | export "$KEY"="$VALUE" 26 | done 27 | 28 | # The setup accommodates two cases: 29 | # 1) Passing the 'RUN_NAME' variable at runtime 30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 31 | if [ -n "$RUN_NAME" ]; 32 | then 33 | export M_RUN_NAME=$RUN_NAME 34 | fi 35 | 36 | # Set up network optimizations 37 | if [ "$RUN_PREFLIGHT" = "true" ]; then 38 | bash preflight.sh 39 | fi 40 | 41 | # Train 42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 43 | 44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-13b\ 45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ 46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=8 remat_policy=qkv_proj_offloaded\ 47 | steps=15 enable_checkpointing=false use_iota_embed=true 48 | -------------------------------------------------------------------------------- 
/MaxText/configs/v5e/llama2_70b.sh: -------------------------------------------------------------------------------- 1 | # Llama2 70B model. 2 | # This config will work out of the box for any number of v5e-256 slices. 3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/v5e/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | # Example to AOT compile: 13 | # bash MaxText/configs/v5e/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 14 | 15 | 16 | # Stop execution if any command exits with error 17 | set -e 18 | 19 | export EXECUTABLE="train" # or train_compile 20 | export RUN_PREFLIGHT="true" 21 | 22 | # Set environment variables 23 | for ARGUMENT in "$@"; do 24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 25 | export "$KEY"="$VALUE" 26 | done 27 | 28 | # The setup accommodates two cases: 29 | # 1) Passing the 'RUN_NAME' variable at runtime 30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 31 | if [ -n "$RUN_NAME" ]; 32 | then 33 | export M_RUN_NAME=$RUN_NAME 34 | fi 35 | 36 | # Set up network optimizations 37 | if [ "$RUN_PREFLIGHT" = "true" ]; then 38 | bash preflight.sh 39 | fi 40 | 41 | # Train 42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 43 | 44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-70b\ 45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ 46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded\ 47 | steps=15 enable_checkpointing=false use_iota_embed=true 48 | -------------------------------------------------------------------------------- /MaxText/configs/v5e/llama2_70b_v5e-16.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # tensor = 8, autoregressive=2 4 | # per_device_batch_size=6 5 | # weight bf16, kv cache bf16 6 | 7 | model_name: "llama2-70b" 8 | sharding_strategy: "experimental" 9 | attention: 'dot_product' 10 | allow_split_physical_axes: True 11 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 12 | replicate_quant_scale: True 13 | 14 | logical_axis_rules: [ 15 | ['embed', []], 16 | ['vocab', ['tensor', 'autoregressive']], 17 | ['activation_batch', []], 18 | ['activation_length', []], 19 | ['activation_embed', []], 20 | ['activation_vocab', ['tensor']], 21 | ['heads', ['tensor', 'autoregressive']], 22 | ['kv', []], 23 | # TODO: fix the wrong XLA ops for the following sharding. 
24 | # ['q_heads', ['tensor', 'autoregressive']], 25 | # ['kv_head_dim', ['autoregressive']], 26 | ['q_heads', ['tensor']], 27 | ['kv_heads', ['tensor']], 28 | ['kv_head_dim', []], 29 | ['activation_prefill_kv_batch', []], 30 | ['activation_kv_batch', ['autoregressive']], 31 | ['activation_kv_heads', ['tensor']], 32 | ['activation_kv_head_dim', []], 33 | ['activation_heads', ['tensor']], 34 | ['activation_kv', ['tensor', 'autoregressive']], 35 | ['norm', []], 36 | ['mlp', ['tensor', 'autoregressive']], 37 | ['activation_mlp', ['tensor', 'autoregressive']], 38 | ['cache_batch_prefill', []], 39 | ['cache_batch', ['autoregressive']], 40 | ['cache_sequence', []], 41 | ['cache_heads', ['tensor']], 42 | ['cache_kv', []], 43 | ] 44 | -------------------------------------------------------------------------------- /MaxText/configs/v5e/llama2_7b.sh: -------------------------------------------------------------------------------- 1 | # Llama2 7B model. 2 | # This config will work out of the box for any number of v5e-256 slices. 3 | # 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # 9 | # Example to invoke this script: 10 | # bash MaxText/configs/v5e/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" 11 | # 12 | # Example to AOT compile: 13 | # bash MaxText/configs/v5e/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 14 | 15 | 16 | # Stop execution if any command exits with error 17 | set -e 18 | 19 | export EXECUTABLE="train" # or train_compile 20 | export RUN_PREFLIGHT="true" 21 | 22 | # Set environment variables 23 | for ARGUMENT in "$@"; do 24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 25 | export "$KEY"="$VALUE" 26 | done 27 | 28 | # The setup accommodates two cases: 29 | # 1) Passing the 'RUN_NAME' variable at runtime 30 | # 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow 31 | if [ -n "$RUN_NAME" ]; 32 | then 33 | export M_RUN_NAME=$RUN_NAME 34 | fi 35 | 36 | # Set up network optimizations 37 | if [ "$RUN_PREFLIGHT" = "true" ]; then 38 | bash preflight.sh 39 | fi 40 | 41 | # Train 42 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 43 | 44 | python3 -m MaxText.$EXECUTABLE MaxText/configs/base.yml model_name=llama2-7b\ 45 | base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ 46 | tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=4 remat_policy=save_qkv_proj\ 47 | steps=15 enable_checkpointing=false use_iota_embed=true -------------------------------------------------------------------------------- /MaxText/configs/v5e/llama3_405b_v5e-64.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # v5e-64 4 | # tensor = 8, autoregressive=8 5 | # per_device_batch_size=1 6 | # weight bf16, kv cache bf16 7 | 8 | model_name: "llama3.1-405b" 9 | sharding_strategy: "experimental" 10 | attention: 'dot_product' 11 | allow_split_physical_axes: True 12 | 
tokenizer_path: "assets/tokenizer_llama3.tiktoken" 13 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 14 | replicate_quant_scale: True 15 | 16 | logical_axis_rules: [ 17 | ['embed', []], 18 | ['vocab', ['tensor', 'autoregressive']], 19 | ['activation_batch', []], 20 | ['activation_length', []], 21 | ['activation_embed', []], 22 | ['activation_vocab', ['tensor', 'autoregressive']], 23 | ['heads', ['tensor', 'autoregressive']], 24 | ['kv', []], 25 | ['kv_heads', ['tensor']], 26 | ['q_heads', ['tensor']], 27 | ['kv_head_dim', []], 28 | ['activation_prefill_kv_batch', []], 29 | ['activation_kv_batch', ['autoregressive']], 30 | ['activation_kv_heads', ['tensor']], 31 | ['activation_kv_head_dim', []], 32 | ['activation_heads', ['tensor']], 33 | ['activation_kv', ['tensor', 'autoregressive']], 34 | ['norm', []], 35 | ['mlp', ['tensor', 'autoregressive']], 36 | ['activation_mlp', ['tensor', 'autoregressive']], 37 | ['cache_batch_prefill', []], 38 | ['cache_batch', ['autoregressive']], 39 | ['cache_sequence', []], 40 | ['cache_heads', ['tensor']], 41 | ['cache_kv', []], 42 | ] 43 | -------------------------------------------------------------------------------- /MaxText/configs/v5e/llama3_70b_v5e-16.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # tensor = 8, autoregressive=2 4 | # per_device_batch_size=6 5 | # weight bf16, kv cache bf16 6 | 7 | model_name: "llama3-70b" 8 | tokenizer_path: "assets/tokenizer_llama3.tiktoken" 9 | sharding_strategy: "experimental" 10 | attention: 'dot_product' 11 | allow_split_physical_axes: True 12 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 13 | replicate_quant_scale: True 14 | 15 | logical_axis_rules: [ 16 | ['embed', []], 17 | ['vocab', ['tensor', 'autoregressive']], 18 | ['activation_batch', []], 19 | ['activation_length', []], 20 | ['activation_embed', []], 21 | ['activation_vocab', ['tensor']], 22 | ['heads', ['tensor', 'autoregressive']], 23 | ['kv', []], 24 | # TODO: fix the wrong XLA ops for the following sharding. 25 | # ['q_heads', ['tensor', 'autoregressive']], 26 | # ['kv_head_dim', ['autoregressive']], 27 | ['q_heads', ['tensor']], 28 | ['kv_heads', ['tensor']], 29 | ['kv_head_dim', []], 30 | ['activation_prefill_kv_batch', []], 31 | ['activation_kv_batch', ['autoregressive']], 32 | ['activation_kv_heads', ['tensor']], 33 | ['activation_kv_head_dim', []], 34 | ['activation_heads', ['tensor']], 35 | ['activation_kv', ['tensor', 'autoregressive']], 36 | ['norm', []], 37 | ['mlp', ['tensor', 'autoregressive']], 38 | ['activation_mlp', ['tensor', 'autoregressive']], 39 | ['cache_batch_prefill', []], 40 | ['cache_batch', ['autoregressive']], 41 | ['cache_sequence', []], 42 | ['cache_heads', ['tensor']], 43 | ['cache_kv', []], 44 | ] 45 | -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-1024 slice. 
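# Note: gpt3_175b_base.sh is driven by positional arguments; judging from the call below, they appear to be per-device batch size, remat policy, and three topology dims (1 x 64 x 8 = 512 chips on a v5p-1024 slice), followed by run name and output directory. gpt3_175b_base.sh itself is the authoritative reference for the exact order.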
4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_1024.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" 16 | -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_12288.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-12288 slice, with a custom topology of 8x16x48. 4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_12288.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_2048.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-2048 slice. 4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_2048.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 8 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_3072.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-3072 slice. 4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_3072.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 12 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_4096.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-4096 slice, with a custom topology of 4x8x64.
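# (4 x 8 x 64 = 2048 chips, matching the 4096 TensorCores of a v5p-4096 slice.)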
4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_4096.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 4 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" -------------------------------------------------------------------------------- /MaxText/configs/v5p/gpt3_175b/v5p_8192.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPT-3 175B Model. 3 | # Train GPT-3 175B on v5p-8192 slice, with a custom topology of 8x16x32. 4 | 5 | # Example to invoke this script: 6 | # bash MaxText/configs/v5p/gpt3_175b/v5p_8192.sh YOUR_RUN gs://YOUR_BUCKET 7 | 8 | set -euox pipefail 9 | 10 | # Read arguments or use defaults from environment variables 11 | RUNNAME=${1:-${RUNNAME:-some-run}} 12 | BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} 13 | 14 | chmod +x MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 15 | ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 32 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" -------------------------------------------------------------------------------- /MaxText/experimental/rl/grpo.yml: -------------------------------------------------------------------------------- 1 | base_config: "base.yml" 2 | 3 | use_grpo: True 4 | train_data_columns: 'prompt' 5 | 6 | learning_rate: 1.e-6 7 | 8 | dataset_type: hf # we currently only support the Hugging Face input pipeline with GRPO. 9 | 10 | # TRL 11 | max_prefill_predict_length: 512 12 | max_target_length: 1024 13 | 14 | adam_b2: 0.999 15 | 16 | # Group Relative Policy Optimization (GRPO) 17 | num_generations: 4 18 | grpo_beta: 0.04 19 | 20 | decode_sampling_strategy: "weighted" 21 | decode_sampling_temperature: 0.9 22 | async_checkpointing: false 23 | -------------------------------------------------------------------------------- /MaxText/experimental/rl/grpo_trainer_test.yml: -------------------------------------------------------------------------------- 1 | base_config: "grpo.yml" 2 | 3 | model_name: "default" 4 | vocab_size: 128256 5 | max_target_length: 32 6 | per_device_batch_size: 1 7 | max_prefill_predict_length: 16 8 | dataset_type: "synthetic" 9 | dtype: "float32" 10 | matmul_precision: "high" 11 | logits_dot_in_fp32: True 12 | prompt: "Hello world this is a test" 13 | init_weights_seed: 42 14 | enable_dropout: False 15 | -------------------------------------------------------------------------------- /MaxText/globals.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | import os.path 18 | 19 | PKG_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | 21 | __all__ = ["PKG_DIR"] 22 | -------------------------------------------------------------------------------- /MaxText/inference/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | model_name: "llama3.1-405b" 4 | sharding_strategy: "experimental" 5 | attention: 'dot_product' 6 | allow_split_physical_axes: True 7 | tokenizer_path: "assets/tokenizer_llama3.tiktoken" 8 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 9 | replicate_quant_scale: True 10 | 11 | inference_server: "ExperimentalMaxtextDisaggregatedServer" 12 | 13 | logical_axis_rules: [ 14 | ['embed', []], 15 | ['vocab', ['tensor', 'autoregressive']], 16 | ['activation_batch', []], 17 | ['activation_length', []], 18 | ['activation_embed', []], 19 | ['activation_vocab', ['tensor', 'autoregressive']], 20 | ['heads', ['tensor', 'autoregressive']], 21 | ['kv', []], 22 | ['kv_heads', ['tensor']], 23 | ['q_heads', ['tensor']], 24 | ['kv_head_dim', []], 25 | ['activation_prefill_kv_batch', []], 26 | ['activation_kv_batch', ['autoregressive']], 27 | ['activation_kv_heads', ['tensor']], 28 | ['activation_kv_head_dim', []], 29 | ['activation_heads', ['tensor']], 30 | ['activation_kv', ['tensor', 'autoregressive']], 31 | ['norm', []], 32 | ['mlp', ['tensor', 'autoregressive']], 33 | ['activation_mlp', ['tensor', 'autoregressive']], 34 | ['cache_batch_prefill', []], 35 | ['cache_batch', []], 36 | ['cache_sequence', []], 37 | ['cache_heads', ['tensor']], 38 | ['cache_kv', []], 39 | ] 40 | -------------------------------------------------------------------------------- /MaxText/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # tensor = 8, autoregressive=2 4 | # per_device_batch_size=6 5 | # weight bf16, kv cache bf16 6 | 7 | model_name: "llama2-70b" 8 | sharding_strategy: "experimental" 9 | attention: 'dot_product' 10 | allow_split_physical_axes: True 11 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 12 | replicate_quant_scale: True 13 | 14 | logical_axis_rules: [ 15 | ['embed', []], 16 | ['vocab', ['tensor', 'autoregressive']], 17 | ['activation_batch', []], 18 | ['activation_length', []], 19 | ['activation_embed', []], 20 | ['activation_vocab', ['tensor']], 21 | ['heads', ['tensor', 'autoregressive']], 22 | ['kv', []], 23 | # TODO: fix the wrong XLA ops for the following sharding. 
24 | # ['q_heads', ['tensor', 'autoregressive']], 25 | # ['kv_head_dim', ['autoregressive']], 26 | ['q_heads', ['tensor']], 27 | ['kv_heads', ['tensor']], 28 | ['kv_head_dim', []], 29 | ['activation_prefill_kv_batch', []], 30 | ['activation_kv_batch', ['autoregressive']], 31 | ['activation_kv_heads', ['tensor']], 32 | ['activation_kv_head_dim', []], 33 | ['activation_heads', ['tensor']], 34 | ['activation_kv', ['tensor', 'autoregressive']], 35 | ['norm', []], 36 | ['mlp', ['tensor', 'autoregressive']], 37 | ['activation_mlp', ['tensor', 'autoregressive']], 38 | ['cache_batch_prefill', []], 39 | ['cache_batch', ['autoregressive']], 40 | ['cache_sequence', []], 41 | ['cache_heads', ['tensor']], 42 | ['cache_kv', []], 43 | ] 44 | -------------------------------------------------------------------------------- /MaxText/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # v5e-64 4 | # tensor = 8, autoregressive=8 5 | # per_device_batch_size=1 6 | # weight bf16, kv cache bf16 7 | 8 | model_name: "llama3.1-405b" 9 | sharding_strategy: "experimental" 10 | attention: 'dot_product' 11 | allow_split_physical_axes: True 12 | tokenizer_path: "assets/tokenizer_llama3.tiktoken" 13 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 14 | replicate_quant_scale: True 15 | 16 | logical_axis_rules: [ 17 | ['embed', []], 18 | ['vocab', ['tensor', 'autoregressive']], 19 | ['activation_batch', []], 20 | ['activation_length', []], 21 | ['activation_embed', []], 22 | ['activation_vocab', ['tensor', 'autoregressive']], 23 | ['heads', ['tensor', 'autoregressive']], 24 | ['kv', []], 25 | ['kv_heads', ['tensor']], 26 | ['q_heads', ['tensor']], 27 | ['kv_head_dim', []], 28 | ['activation_prefill_kv_batch', []], 29 | ['activation_kv_batch', ['autoregressive']], 30 | ['activation_kv_heads', ['tensor']], 31 | ['activation_kv_head_dim', []], 32 | ['activation_heads', ['tensor']], 33 | ['activation_kv', ['tensor', 'autoregressive']], 34 | ['norm', []], 35 | ['mlp', ['tensor', 'autoregressive']], 36 | ['activation_mlp', ['tensor', 'autoregressive']], 37 | ['cache_batch_prefill', []], 38 | ['cache_batch', ['autoregressive']], 39 | ['cache_sequence', []], 40 | ['cache_heads', ['tensor']], 41 | ['cache_kv', []], 42 | ] 43 | -------------------------------------------------------------------------------- /MaxText/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml: -------------------------------------------------------------------------------- 1 | base_config: "inference_jetstream.yml" 2 | 3 | # tensor = 8, autoregressive=2 4 | # per_device_batch_size=6 5 | # weight bf16, kv cache bf16 6 | 7 | model_name: "llama3-70b" 8 | tokenizer_path: "assets/tokenizer_llama3.tiktoken" 9 | sharding_strategy: "experimental" 10 | attention: 'dot_product' 11 | allow_split_physical_axes: True 12 | # Used to replicate the quantization scale to avoid the inefficient XLA fusion. 13 | replicate_quant_scale: True 14 | 15 | logical_axis_rules: [ 16 | ['embed', []], 17 | ['vocab', ['tensor', 'autoregressive']], 18 | ['activation_batch', []], 19 | ['activation_length', []], 20 | ['activation_embed', []], 21 | ['activation_vocab', ['tensor']], 22 | ['heads', ['tensor', 'autoregressive']], 23 | ['kv', []], 24 | # TODO: fix the wrong XLA ops for the following sharding. 
25 | # ['q_heads', ['tensor', 'autoregressive']], 26 | # ['kv_head_dim', ['autoregressive']], 27 | ['q_heads', ['tensor']], 28 | ['kv_heads', ['tensor']], 29 | ['kv_head_dim', []], 30 | ['activation_prefill_kv_batch', []], 31 | ['activation_kv_batch', ['autoregressive']], 32 | ['activation_kv_heads', ['tensor']], 33 | ['activation_kv_head_dim', []], 34 | ['activation_heads', ['tensor']], 35 | ['activation_kv', ['tensor', 'autoregressive']], 36 | ['norm', []], 37 | ['mlp', ['tensor', 'autoregressive']], 38 | ['activation_mlp', ['tensor', 'autoregressive']], 39 | ['cache_batch_prefill', []], 40 | ['cache_batch', ['autoregressive']], 41 | ['cache_sequence', []], 42 | ['cache_heads', ['tensor']], 43 | ['cache_kv', []], 44 | ] 45 | -------------------------------------------------------------------------------- /MaxText/inference/gpu/README.md: -------------------------------------------------------------------------------- 1 | ## Benchmarking Scripts 2 | 3 | This directory contains scripts used for baseline benchmarking with fixed parameters. -------------------------------------------------------------------------------- /MaxText/inference/jetstream_pathways/README.md: -------------------------------------------------------------------------------- 1 | ## Build and upload MaxText + JetStream + Pathways Server image 2 | 3 | These instructions are to build the MaxText + JetStream + Pathways Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the MaxText framework. 4 | 5 | ``` 6 | docker build -t jetstream-pathways . 7 | docker tag jetstream-pathways us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pathways:latest 8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pathways:latest 9 | ``` 10 | 11 | If you would like to change the version of MaxText or JetStream the image is built off of, change the `MAXTEXT_VERSION` / `JETSTREAM_VERSION` environment variable: 12 | ``` 13 | ENV MAXTEXT_VERSION= 14 | ENV JETSTREAM_VERSION= 15 | ``` -------------------------------------------------------------------------------- /MaxText/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2024 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | cd /maxtext 18 | python3 -m MaxText.maxengine_server $@ 19 | -------------------------------------------------------------------------------- /MaxText/inference/maxengine_server/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Ubuntu:22.04 16 | # Use Ubuntu 22.04 from Docker Hub. 17 | # https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04 18 | FROM ubuntu:22.04 19 | 20 | ENV DEBIAN_FRONTEND=noninteractive 21 | ENV MAXTEXT_VERSION=main 22 | ENV JETSTREAM_VERSION=main 23 | 24 | RUN apt -y update && apt install -y --no-install-recommends \ 25 | ca-certificates \ 26 | git \ 27 | python3.10 \ 28 | python3-pip 29 | 30 | RUN update-alternatives --install \ 31 | /usr/bin/python3 python3 /usr/bin/python3.10 1 32 | 33 | RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \ 34 | git clone https://github.com/AI-Hypercomputer/JetStream.git 35 | 36 | RUN cd maxtext/ && \ 37 | git checkout ${MAXTEXT_VERSION} && \ 38 | bash setup.sh 39 | 40 | RUN cd /JetStream && \ 41 | git checkout ${JETSTREAM_VERSION} && \ 42 | python3 -m pip install -e . 43 | 44 | COPY maxengine_server_entrypoint.sh /usr/bin/ 45 | 46 | RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh 47 | 48 | ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"] -------------------------------------------------------------------------------- /MaxText/inference/maxengine_server/README.md: -------------------------------------------------------------------------------- 1 | ## Build and upload Maxengine Server image 2 | 3 | These instructions are to build the Maxengine Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the MaxText framework. 4 | 5 | ``` 6 | docker build -t maxengine-server . 7 | docker tag maxengine-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest 8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest 9 | ``` 10 | 11 | If you would like to change the version of MaxText or JetStream the image is built off of, change the `MAXTEXT_VERSION` / `JETSTREAM_VERSION` environment variable: 12 | ``` 13 | ENV MAXTEXT_VERSION= 14 | ENV JETSTREAM_VERSION= 15 | ``` -------------------------------------------------------------------------------- /MaxText/inference/maxengine_server/maxengine_server_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2024 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
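# Start the MaxEngine (JetStream) server on top of the base config; any additional command-line arguments are forwarded as config overrides.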
16 | 17 | cd /maxtext 18 | python3 -m MaxText.maxengine_server \ 19 | MaxText/configs/base.yml $@ 20 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/matmul/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/matmul/matmul_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """matrix multiplication data types""" 15 | 16 | 17 | import jax 18 | 19 | from MaxText.inference_mlperf.matmul import timing_util 20 | 21 | _PROFILE = False 22 | MATMUL_SIZES = [(250, 2048)] 23 | 24 | _INT4 = jax.numpy.int4 25 | _INT8 = jax.numpy.int8 26 | _DEFAULT = jax.numpy.bfloat16 27 | 28 | 29 | def f(X, Y): 30 | return jax.lax.batch_matmul(X, Y) 31 | 32 | 33 | f_jit = jax.jit(f) 34 | 35 | num_matmuls, matrix_size = MATMUL_SIZES[0] 36 | 37 | for dtypeA, dtypeB in [ 38 | (_INT4, _INT4), 39 | (_INT4, _INT8), 40 | (_INT8, _INT4), 41 | (_INT8, _INT8), 42 | (_INT8, _DEFAULT), 43 | (_DEFAULT, _DEFAULT), 44 | ]: 45 | A = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=dtypeA) 46 | B = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=dtypeB) 47 | 48 | print(f"A, B shape is {f(A, B).shape}. 
A dtype is {A.dtype}, B dtype is {B.dtype} and prod type is {f(A, B).dtype}") 49 | timing_util.simple_timeit(f_jit, A, B, task="matmul_" + str(matrix_size), enable_profile=_PROFILE) 50 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/matmul/timing_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Timing utility functions """ 15 | 16 | import datetime 17 | import os.path 18 | import tempfile 19 | 20 | import jax 21 | 22 | 23 | def simple_timeit(f, *args, tries=10, task=None, enable_profile=False): 24 | """Simple utility to time a function for multiple runs""" 25 | assert task is not None 26 | 27 | trace_name = f"{task}" # + '_' ]+ ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) 28 | temp_dir = tempfile.gettempdir() 29 | trace_dir = os.path.join(temp_dir, trace_name) 30 | print(trace_dir) 31 | 32 | outcomes_ms = [] 33 | jax.block_until_ready(f(*args)) # warm it up! 34 | if enable_profile: 35 | jax.profiler.start_trace(trace_dir) 36 | for _ in range(tries): 37 | s = datetime.datetime.now() 38 | jax.block_until_ready(f(*args)) 39 | e = datetime.datetime.now() 40 | outcomes_ms.append(1000 * (e - s).total_seconds()) 41 | if enable_profile: 42 | jax.profiler.stop_trace() 43 | average_time_ms = sum(outcomes_ms) / len(outcomes_ms) 44 | print(f"Average time ms for mm for {task} is {round(average_time_ms, 3)}") 45 | return average_time_ms / 1000 46 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.31.0 2 | nltk==3.8.1 3 | evaluate==0.4.0 4 | absl-py==1.4.0 5 | rouge-score==0.1.2 6 | sentencepiece==0.1.99 7 | accelerate==0.21.0 8 | omegaconf 9 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/trillium/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/user.conf: -------------------------------------------------------------------------------- 1 | # The format of this config file is 'key = value'. 2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t. 3 | # Model maybe '*' as wildcard. In that case the value applies to all models. 4 | # All times are in milli seconds 5 | 6 | # Set performance_sample_count for each model. 7 | llama2-70b.*.performance_sample_count_override = 24576 8 | mixtral-8x7b.*.performance_sample_count_override = 15000 9 | *.Offline.min_duration = 600000 10 | 11 | 12 | # In Offline scenario, we always have one query. But LoadGen maps this to 13 | # min_sample_count internally in Offline scenario. If the dataset size is larger 14 | # than 24576 we limit the min_query_count to 24576 and otherwise we use 15 | # the dataset size as the limit 16 | llama2-70b.Offline.min_query_count = 24576 17 | mixtral-8x7b.Offline.min_query_count = 15000 18 | 19 | # These fields should be defined and overridden by user.conf. 20 | *.Offline.target_qps = 5.0 21 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/user100.conf: -------------------------------------------------------------------------------- 1 | # The format of this config file is 'key = value'. 2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t. 3 | # Model maybe '*' as wildcard. In that case the value applies to all models. 4 | # All times are in milli seconds 5 | 6 | # Set performance_sample_count for each model. 7 | #llama2-70b.*.performance_sample_count_override = 24576 8 | llama2-70b.*.performance_sample_count_override = 100 9 | mixtral-8x7b.*.performance_sample_count_override = 100 10 | #*.Offline.min_duration = 600000 11 | *.Offline.min_duration = 60 12 | 13 | 14 | # In Offline scenario, we always have one query. But LoadGen maps this to 15 | # min_sample_count internally in Offline scenario. If the dataset size is larger 16 | # than 24576 we limit the min_query_count to 24576 and otherwise we use 17 | # the dataset size as the limit 18 | #llama2-70b.Offline.min_query_count = 24576 19 | llama2-70b.Offline.min_query_count = 100 20 | mixtral-8x7b.Offline.min_query_count = 100 21 | 22 | 23 | # These fields should be defined and overridden by user.conf. 24 | *.Offline.target_qps = 5.0 25 | -------------------------------------------------------------------------------- /MaxText/inference_mlperf/user5000.conf: -------------------------------------------------------------------------------- 1 | # The format of this config file is 'key = value'. 2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t. 3 | # Model maybe '*' as wildcard. In that case the value applies to all models. 4 | # All times are in milli seconds 5 | 6 | # Set performance_sample_count for each model. 7 | #llama2-70b.*.performance_sample_count_override = 24576 8 | llama2-70b.*.performance_sample_count_override = 5000 9 | mixtral-8x7b.*.performance_sample_count_override = 5000 10 | #*.Offline.min_duration = 600000 11 | *.Offline.min_duration = 600 12 | 13 | 14 | # In Offline scenario, we always have one query. But LoadGen maps this to 15 | # min_sample_count internally in Offline scenario. 
If the dataset size is larger 16 | # than 24576 we limit the min_query_count to 24576 and otherwise we use 17 | # the dataset size as the limit 18 | #llama2-70b.Offline.min_query_count = 24576 19 | llama2-70b.Offline.min_query_count = 5000 20 | mixtral-8x7b.Offline.min_query_count = 5000 21 | 22 | 23 | # These fields should be defined and overridden by user.conf. 24 | *.Offline.target_qps = 5.0 25 | -------------------------------------------------------------------------------- /MaxText/input_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/kernels/megablox/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Megablox kernel""" 15 | 16 | from MaxText.kernels.megablox.ops import gmm 17 | -------------------------------------------------------------------------------- /MaxText/kernels/megablox/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Common utilities for GMM kernels.""" 16 | 17 | import re 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | def is_tpu() -> bool: 24 | return "TPU" in jax.devices()[0].device_kind 25 | 26 | 27 | def tpu_kind() -> str: 28 | """Query identification string for the currently attached TPU.""" 29 | return jax.devices()[0].device_kind 30 | 31 | 32 | _TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)") 33 | 34 | 35 | def tpu_generation() -> int: 36 | """Generation number of the currently attached TPU.""" 37 | if version := _TPU_KIND_PATTERN.match(tpu_kind()): 38 | return int(version[1]) 39 | raise NotImplementedError("only TPU devices are supported") 40 | 41 | 42 | def supports_bfloat16_matmul() -> bool: 43 | """Does the currently attached device support bfloat16 matmul inputs?""" 44 | return not is_tpu() or tpu_generation() >= 4 45 | 46 | 47 | def assert_is_supported_dtype(dtype: jnp.dtype) -> None: 48 | if dtype not in (jnp.bfloat16, jnp.float32): 49 | raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.") 50 | 51 | 52 | def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype: 53 | """A type to which both inputs should be adapted before the dot product.""" 54 | # bf16xbf16 matmul is only supported since the TPU v4 generation. In case of mixed 55 | # input precision, we need to convert the bf16 argument to fp32 beforehand. 56 | if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16: 57 | return jnp.bfloat16 58 | else: 59 | return jnp.float32 60 | -------------------------------------------------------------------------------- /MaxText/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/layers/initializers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Initializers.""" 16 | 17 | from typing import Callable, Tuple, Union 18 | 19 | import jax 20 | 21 | from flax import linen as nn 22 | 23 | from MaxText.common_types import Array, DType, Shape, PRNGKey 24 | 25 | Initializer = Callable[[PRNGKey, Shape, DType], Array] 26 | InitializerAxis = Union[int, Tuple[int, ...]] 27 | NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] 28 | 29 | default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) 30 | 31 | default_bias_init = jax.nn.initializers.constant(0.0) 32 | 33 | 34 | def nd_dense_init(scale, mode, distribution): 35 | """Initializer with in_axis, out_axis set at call time.""" 36 | 37 | def init_fn(key, shape, dtype, in_axis, out_axis): 38 | fn = jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis, out_axis) 39 | return fn(key, shape, dtype) 40 | 41 | return init_fn 42 | -------------------------------------------------------------------------------- /MaxText/layers/normalizations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Normalization Layers.""" 16 | 17 | from typing import Any, Tuple, Optional 18 | 19 | from flax import linen as nn 20 | from jax import lax 21 | import jax 22 | import jax.numpy as jnp 23 | from MaxText import max_logging 24 | from MaxText.layers.initializers import Initializer 25 | 26 | 27 | class RMSNorm(nn.Module): 28 | """RMS normalization.""" 29 | 30 | epsilon: float = 1e-6 31 | dtype: Any = jnp.float32 32 | weight_dtype: Any = jnp.float32 33 | kernel_axes: Tuple[Optional[str], ...] 
= () 34 | scale_init: Initializer = nn.initializers.ones 35 | parameter_memory_host_offload: bool = False 36 | 37 | @nn.compact 38 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 39 | """Applies layer normalization on the input.""" 40 | x = jnp.asarray(x, jnp.float32) 41 | features = x.shape[-1] 42 | mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) 43 | y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) 44 | scale = self.param( 45 | "scale", 46 | nn.with_logical_partitioning(self.scale_init, self.kernel_axes), 47 | (features,), 48 | self.weight_dtype, 49 | ) 50 | # Move scale to device if parameter offloading is enabled 51 | if self.parameter_memory_host_offload: 52 | max_logging.log("normalizations.py: Moving scale parameter to device") 53 | scale = jax.device_put(scale, jax._src.sharding_impls.TransferToMemoryKind("device")) 54 | 55 | scale = jnp.asarray(scale, self.dtype) 56 | return y * scale 57 | -------------------------------------------------------------------------------- /MaxText/load_and_quantize_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """CLI utility for loading and quantizing a checkpoint.""" 16 | 17 | import os 18 | from typing import Sequence 19 | 20 | from absl import app 21 | 22 | import jax 23 | 24 | from MaxText import max_utils 25 | from MaxText import maxengine 26 | from MaxText import pyconfig 27 | 28 | 29 | def main(argv: Sequence[str]) -> None: 30 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") 31 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 32 | 33 | config = pyconfig.initialize(argv) 34 | validate_config(config) 35 | max_utils.print_system_information() 36 | 37 | engine = maxengine.MaxEngine(config) 38 | rng = jax.random.PRNGKey(1234) 39 | rng, rng_load_params = jax.random.split(rng) 40 | 41 | # load_params will load a checkpoint and quantize if the following parameters are set: 42 | # quantization=$valid_quantization_type \ 43 | # save_quantized_params_path=$gsbucket_path \ 44 | # checkpoint_is_quantized=false (default) 45 | engine.load_params(rng_load_params) 46 | 47 | 48 | def validate_config(config): 49 | assert config.load_full_state_path == "", "Operation on full states not supported! Convert to parameter checkpoint first." 50 | 51 | 52 | if __name__ == "__main__": 53 | app.run(main) 54 | -------------------------------------------------------------------------------- /MaxText/max_logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2023 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Stub for logging utilities. Right now just meant to avoid raw prints.""" 18 | 19 | 20 | def log(user_str): 21 | print(user_str, flush=True) 22 | -------------------------------------------------------------------------------- /MaxText/scratch_code/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/scratch_code/gemma_7b.sh: -------------------------------------------------------------------------------- 1 | export M_LOAD_PARAMETERS_PATH=gs://runner-maxtext-logs/reroll5/checkpoints/10/items/ 2 | export M_PER_DEVICE_BATCH_SIZE=24 3 | export M_MAX_PREFILL_PREDICT_LENGTH=1024 4 | export M_MAX_TARGET_LENGTH=2048 5 | 6 | #python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false 7 | 8 | python3 -m MaxText.maxengine_server MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false 9 | -------------------------------------------------------------------------------- /MaxText/scratch_code/run_inference_microbenchmark.sh: -------------------------------------------------------------------------------- 1 | # llama2-7b 2 | python3 -m MaxText.inference_microbenchmark \ 3 | MaxText/configs/base.yml \ 4 | async_checkpointing=false \ 5 | attention=autoselected \ 6 | dataset_path=gs://maxtext-dataset \ 7 | ici_fsdp_parallelism=1 \ 8 | ici_autoregressive_parallelism=-1 \ 9 | max_prefill_predict_length=1024 \ 10 | max_target_length=2048 \ 11 | per_device_batch_size=16 \ 12 | quantization=int8 \ 13 | quantize_kvcache=True \ 14 | steps=10 \ 15 | scan_layers=false \ 16 | model_name=llama2-7b \ 17 | weight_dtype=bfloat16 \ 18 | tokenizer_path=assets/tokenizer.llama2 19 | -------------------------------------------------------------------------------- /MaxText/scratch_code/setup_transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 3 | python3 -m pip install tokenizers 
-U 4 | python3 -m pip install transformers -U 5 | -------------------------------------------------------------------------------- /MaxText/test_assets/.gitignore: -------------------------------------------------------------------------------- 1 | golden_data_gemma-2b.jsonl 2 | golden_data_gemma2-27b.jsonl 3 | golden_data_gemma2-2b.jsonl 4 | golden_data_gemma2-9b.jsonl 5 | golden_data_gemma3-12b.jsonl 6 | golden_data_gemma3-27b.jsonl 7 | golden_data_gemma3-4b.jsonl 8 | golden_data_llama2-70b.jsonl 9 | golden_data_llama2-7b.jsonl 10 | golden_data_llama3-70b.jsonl 11 | golden_data_llama3-8b.jsonl 12 | golden_data_llama3.1-405b.jsonl 13 | golden_data_llama3.1-70b.jsonl 14 | golden_data_llama3.1-8b.jsonl 15 | golden_data_llama3.3-70b.jsonl 16 | golden_data_llama4-17b-16e.jsonl 17 | golden_data_mistral-7b.jsonl 18 | golden_data_mixtral-8x22b.jsonl 19 | golden_data_mixtral-8x7b.jsonl 20 | -------------------------------------------------------------------------------- /MaxText/test_assets/golden_data_grpo_default.jsonl: -------------------------------------------------------------------------------- 1 | {"maxtext_loss": 0.0, "input_ids": [128000, 9906, 1917, 420, 374, 264, 1296, 128000, 9906, 1917, 420, 374, 264, 1296, 0, 0, 0, 0], "generated_completions": [128000, 9906, 1917, 420, 374, 264, 1296, 62387, 64248, 94859, 15603, 112205, 13091, 32909, 90304, 116037, 114513, 35686, 26560, 91645, 85220, 105433, 75171, 0, 0, 0, 0], "maxtext_per_token_logps_no_ckpt_loading": [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -12.265657424926758, -12.379983901977539, -13.414685249328613, -15.10253620147705, -13.143253326416016, -11.190017700195312, -12.384369850158691, -0.0, -0.0, -0.0, -0.0], "avg_kl": 0.0, "avg_advantage": 0.0} 2 | -------------------------------------------------------------------------------- /MaxText/test_assets/golden_data_sft_default.jsonl: -------------------------------------------------------------------------------- 1 | {"data": {"messages": [{"role": "user", "content": "Hello, what is your name?"}, {"role": "assistant", "content": "I am a chatbot. 
How can I help?"}]}, "tokens": [1, 518, 25580, 29962, 15043, 29892, 825, 338, 596, 1024, 29973, 518, 29914, 25580, 29962, 306, 626, 263, 13563, 7451, 29889, 1128, 508, 306, 1371, 29973, 29871, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "token_log_probs": [-10.455509185791016, -11.176178932189941, -11.196348190307617, -10.699331283569336, -10.40202808380127, -10.493494987487793, -10.981508255004883, -9.65949821472168, -10.184316635131836, -11.546117782592773, -11.033979415893555, -11.696186065673828, -10.98974609375, -10.65627670288086, -9.982662200927734, -11.240318298339844, -12.635238647460938, -9.757575988769531, -12.000450134277344, -11.398622512817383, -10.542476654052734, -10.546899795532227, -11.729068756103516, -10.480279922485352, -11.757697105407715, -10.342456817626953, -9.775711059570312]} 2 | -------------------------------------------------------------------------------- /MaxText/test_assets/test_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/MaxText/test_assets/test_image.jpg -------------------------------------------------------------------------------- /MaxText/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/tests/aot_hlo_identical_script.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # https://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | # Bash script to run AOT and real runs (on a v4-8) 14 | # This needs to run via a bash script so the AOT and real runs each 15 | # initialize jax/XLA separtely (e.g. separate dump directories) 16 | # and we do not contaminate the XLA flags of the second run with the first. 
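# Usage: bash MaxText/tests/aot_hlo_identical_script.sh <compile_dump_dir> <train_dump_dir> "<extra config overrides>"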
17 | 18 | compile_dump_dir=$1 19 | train_dump_dir=$2 20 | custom_args=$3 21 | 22 | shared_args="configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=compile_equivalent_test dataset_path=gs://maxtext-dataset dataset_type=synthetic steps=5 enable_checkpointing=False $custom_args" 23 | aot_args="compile_topology=v4-8 compile_topology_num_slices=1" 24 | 25 | export XLA_FLAGS=--xla_dump_to=${compile_dump_dir} 26 | python3 -m MaxText.train_compile $shared_args $aot_args 27 | 28 | export XLA_FLAGS=--xla_dump_to=${train_dump_dir} 29 | python3 -m MaxText.train $shared_args 30 | 31 | -------------------------------------------------------------------------------- /MaxText/tests/hf_checkpoint_conversion_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """ Tests for HuggingFace checkpoint conversion """ 18 | 19 | import numpy as np 20 | from MaxText.max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope 21 | import unittest 22 | 23 | 24 | class HFCheckpointConversionTest(unittest.TestCase): 25 | 26 | def test_huggingface_to_maxtext_back_to_huggingface_flow(self): 27 | base_num_query_heads = 16 28 | head_dim = 32 29 | wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape( 30 | base_num_query_heads * head_dim, base_num_query_heads * head_dim 31 | ) 32 | wq1 = wq.transpose() 33 | wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) 34 | 35 | wq3 = permute_to_match_maxtext_rope(wq2) 36 | stack_shape = (1,) 37 | x = np.zeros(stack_shape + wq3.shape, dtype=np.float16) 38 | x[0, ...] = wq3 39 | x = np.transpose(x, axes=(1, 0, 2, 3)) 40 | 41 | x = x[:, 0, :, :] 42 | wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1") 43 | wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) 44 | wq6 = wq5.transpose() 45 | 46 | # The round trip must reproduce the original weights exactly; assert so a mismatch fails the test. 47 | self.assertTrue(np.array_equal(wq, wq6), "wq does not match wq6") 48 | self.assertTrue(np.array_equal(wq1, wq5), "wq1 does not match wq5") 49 | self.assertTrue(np.array_equal(wq2, wq4), "wq2 does not match wq4") 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /MaxText/tests/inference/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/tests/inference/test_llama2_7b_bf16.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the arguments in an array 4 | args=( 5 | "-m" 6 | "MaxText.decode" 7 | "MaxText/configs/base.yml" 8 | "tokenizer_path=assets/tokenizer.llama2" 9 | "model_name=llama2-7b" 10 | "load_parameters_path=gs://runner-maxtext-logs/direct_generate_param_only_checkpoint_2024-06-11-04-13/checkpoints/0/items/" 11 | "checkpoint_is_quantized=false" 12 | "weight_dtype=bfloat16" 13 | "max_prefill_predict_length=16" 14 | "max_target_length=32" 15 | "ici_fsdp_parallelism=1" 16 | "ici_autoregressive_parallelism=1" 17 | "ici_tensor_parallelism=-1" 18 | "scan_layers=false" 19 | "per_device_batch_size=1" 20 | "attention=paged" 21 | "pagedattn_num_pages=64" 22 | "pagedattn_tokens_per_page=8" 23 | "pagedattn_pages_per_compute_block=4" 24 | ) 25 | 26 | # Execute the Python script with the arguments 27 | python3 "${args[@]}" 28 | 29 | -------------------------------------------------------------------------------- /MaxText/tests/inference/test_llama2_7b_int8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the arguments in an array 4 | args=( 5 | "-m" 6 | "MaxText.decode" 7 | "MaxText/configs/base.yml" 8 | "tokenizer_path=assets/tokenizer.llama2" 9 | "model_name=llama2-7b" 10 | "load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_" 11 | "checkpoint_is_quantized=true" 12 | "quantization=int8" 13 | "weight_dtype=bfloat16" 14 | "max_prefill_predict_length=16" 15 | "max_target_length=32" 16 | "ici_fsdp_parallelism=1" 17 | "ici_autoregressive_parallelism=1" 18 | "ici_tensor_parallelism=-1" 19 | "scan_layers=false" 20 | "per_device_batch_size=1" 21 | "attention=paged" 22 | "pagedattn_num_pages=64" 23 | "pagedattn_tokens_per_page=8" 24 | "pagedattn_pages_per_compute_block=4" 25 | ) 26 | 27 | # Execute the Python script with the arguments 28 | python3 "${args[@]}" 29 | -------------------------------------------------------------------------------- /MaxText/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | -------------------------------------------------------------------------------- /MaxText/tests/integration_tests/inference_microbenchmark_smoke_test.py: -------------------------------------------------------------------------------- 1 | """Copyright 2024 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | """ Smoke test for inference microbenchmark""" 17 | import jax 18 | import os.path 19 | import pytest 20 | import unittest 21 | from absl.testing import absltest 22 | 23 | from MaxText import pyconfig 24 | from MaxText.globals import PKG_DIR 25 | from MaxText.inference_microbenchmark import run_benchmarks 26 | 27 | 28 | class Inference_Microbenchmark(unittest.TestCase): 29 | """integration test for inference microbenchmark""" 30 | 31 | @pytest.mark.integration_test 32 | @pytest.mark.tpu_only 33 | def test(self): 34 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") 35 | config = pyconfig.initialize( 36 | [ 37 | None, 38 | os.path.join(PKG_DIR, "configs", "tpu_smoke_test.yml"), 39 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}", 40 | "ici_autoregressive_parallelism=-1", 41 | "ici_fsdp_parallelism=1", 42 | "max_prefill_predict_length=1024", 43 | "max_target_length=2048", 44 | "scan_layers=false", 45 | "weight_dtype=bfloat16", 46 | "attention=dot_product", 47 | "skip_jax_distributed_system=True", 48 | ] 49 | ) 50 | run_benchmarks(config) 51 | 52 | 53 | if __name__ == "__main__": 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /MaxText/tests/integration_tests/shmap_collective_matmul_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2024 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | """Integration test for pedagogical_examples/shmap_collective_matmul.py""" 18 | 19 | import os.path 20 | import sys 21 | 22 | import pytest 23 | 24 | from MaxText.globals import PKG_DIR 25 | 26 | sys.path.append(os.path.join(os.path.dirname(PKG_DIR), "pedagogical_examples")) 27 | 28 | # Uncomment the import when b/415022795 is fixed 29 | # from pedagogical_examples.shmap_collective_matmul import main 30 | 31 | 32 | @pytest.mark.skip(reason="Enable when b/415022795 is fixed") 33 | @pytest.mark.integration_test 34 | @pytest.mark.tpu_only 35 | def test_shmap_collective_matmul_example(): 36 | """Validate Pedagogical Example, Shmap_collective_matmul.""" 37 | # Uncomment main() assertion when b/415022795 is fixed 38 | # assert main() is True 39 | -------------------------------------------------------------------------------- /MaxText/tests/train_gpu_smoke_test.py: -------------------------------------------------------------------------------- 1 | """Copyright 2024 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | 16 | """ Smoke test """ 17 | import os 18 | import unittest 19 | 20 | from absl.testing import absltest 21 | 22 | from MaxText.train import main as train_main 23 | from MaxText.globals import PKG_DIR 24 | 25 | 26 | class Train(unittest.TestCase): 27 | """Smoke test for GPUs.""" 28 | 29 | def test_tiny_config(self): 30 | test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable 31 | train_main( 32 | [ 33 | None, 34 | os.path.join(PKG_DIR, "configs", "gpu_smoke_test.yml"), 35 | # pylint: disable=f-string-without-interpolation 36 | f"base_output_directory=gs://runner-maxtext-logs", 37 | "run_name=runner_test", 38 | r"dataset_path=gs://maxtext-dataset", 39 | "enable_checkpointing=False", 40 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}", 41 | "enable_goodput_recording=False", 42 | "enable_checkpoint_cloud_logger=False", 43 | "monitor_goodput=False", 44 | ] 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /MaxText/tests/train_smoke_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2023 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | """ Smoke test """ 18 | import os 19 | import unittest 20 | 21 | from absl.testing import absltest 22 | 23 | from MaxText.train import main as train_main 24 | from MaxText.globals import PKG_DIR 25 | 26 | 27 | class Train(unittest.TestCase): 28 | """Smoke test G3 only""" 29 | 30 | def test_tiny_config(self): 31 | test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable 32 | train_main( 33 | [ 34 | None, 35 | os.path.join(PKG_DIR, "configs", "base.yml"), 36 | # pylint: disable=f-string-without-interpolation 37 | f"base_output_directory=gs://runner-maxtext-logs", 38 | "run_name=runner_test", 39 | r"dataset_path=gs://maxtext-dataset", 40 | "base_emb_dim=8", 41 | "base_num_query_heads=4", 42 | "base_num_kv_heads=4", 43 | "base_mlp_dim=32", 44 | "base_num_decoder_layers=8", 45 | "head_dim=128", 46 | "per_device_batch_size=2", 47 | "max_target_length=1024", 48 | "dataset_type=synthetic", 49 | "steps=10", 50 | "enable_checkpointing=False", 51 | rf"tokenizer_path={os.path.join(os.path.dirname(PKG_DIR), 'assets', 'tokenizer.llama2')}", 52 | "enable_goodput_recording=False", 53 | "enable_checkpoint_cloud_logger=False", 54 | "monitor_goodput=False", 55 | ] 56 | ) 57 | 58 | 59 | if __name__ == "__main__": 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /MaxText/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | -------------------------------------------------------------------------------- /assets/tokenizer: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer -------------------------------------------------------------------------------- /assets/tokenizer.gemma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.gemma -------------------------------------------------------------------------------- /assets/tokenizer.gemma3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.gemma3 -------------------------------------------------------------------------------- /assets/tokenizer.llama2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.llama2 -------------------------------------------------------------------------------- /assets/tokenizer.mistral-v1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.mistral-v1 -------------------------------------------------------------------------------- /assets/tokenizer.mistral-v3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/2c82db1872ba4e80ea240e123229bc87ea69591c/assets/tokenizer.mistral-v3 -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /benchmarks/disruption_management/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | https://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | -------------------------------------------------------------------------------- /benchmarks/mmlu/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /benchmarks/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | """Copyright 2025 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | """ 15 | -------------------------------------------------------------------------------- /benchmarks/xpk_configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2024 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import dataclasses 18 | 19 | 20 | # This is needed to prevent circular imports. 21 | @dataclasses.dataclass 22 | class XpkClusterConfig: 23 | """Holds details related to a XPK cluster to run workloads on.""" 24 | 25 | cluster_name: str 26 | project: str 27 | zone: str 28 | device_type: str 29 | -------------------------------------------------------------------------------- /code_style.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Clean up Python code using Pylint & Pyink 16 | # Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance 17 | 18 | set -e # Exit immediately if any command fails 19 | 20 | FOLDERS_TO_FORMAT=("MaxText" "pedagogical_examples") 21 | LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2) 22 | 23 | # Check for --check flag 24 | CHECK_ONLY_PYINK_FLAGS="" 25 | if [[ "$1" == "--check" ]]; then 26 | CHECK_ONLY_PYINK_FLAGS="--check --diff --color" 27 | fi 28 | 29 | for folder in "${FOLDERS_TO_FORMAT[@]}" 30 | do 31 | pyink "$folder" ${CHECK_ONLY_PYINK_FLAGS} --pyink-indentation=2 --line-length=${LINE_LENGTH} 32 | done 33 | 34 | for folder in "${FOLDERS_TO_FORMAT[@]}" 35 | do 36 | # pylint doesn't change files, only reports errors. 37 | pylint --disable R0401,R0917,W0201,W0613 "./$folder" 38 | done 39 | 40 | echo "Successfully cleaned up all code." 41 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2023 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script downloads c4/en/3.0.1 to your gcs bucket directory 18 | # Usage: bash download_dataset.sh <gcp-project> <gcs-bucket-path> 19 | # Usage example: bash download_dataset.sh cloud-tpu-multipod-dev gs://maxtext-dataset 20 | 21 | function remove_trailing_slash { 22 | if [[ $1 =~ /$ ]]; then # Check if the path ends with a slash 23 | echo "${1::-1}" # Remove the last character 24 | else 25 | echo "$1" # Output the path as-is 26 | fi 27 | } 28 | 29 | gsutil -u $1 -m cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' $(remove_trailing_slash $2)/c4/en/3.0.1 30 | -------------------------------------------------------------------------------- /end_to_end/test_jdi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Running the jax.distributed.initialize test"; 4 | python3 -c "import jax; jax.distributed.initialize()"; 5 | if [[ "$?" -eq "0" ]]; then 6 | echo "Test exit status 0, success!" 7 | else 8 | echo "Non-zero exit status, test failed!"
9 |   exit 1 10 | fi 11 | -------------------------------------------------------------------------------- /end_to_end/test_mtc_phase_2_save_path.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H) 5 | OUTPUT_PATH=${2} 6 | DATASET_PATH=${3} 7 | 8 | export TPU_PREMAPPED_BUFFER_SIZE=20000014336 9 | export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336 10 | 11 | # Train and save checkpoint 12 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ 13 | steps=100 checkpoint_period=200 run_name=$RUN_NAME enable_emergency_checkpoint=true local_checkpoint_directory=/local local_checkpoint_period=20 use_replicator_service=True replicator_backup_interval_minutes=5 metrics_file='saved_metrics.txt' 14 | -------------------------------------------------------------------------------- /end_to_end/test_multi_tier_checkpointing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H) 5 | OUTPUT_PATH=${2} 6 | DATASET_PATH=${3} 7 | 8 | export TPU_PREMAPPED_BUFFER_SIZE=20000014336 9 | export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336 10 | 11 | # Train and save checkpoint 12 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ 13 | steps=100 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='saved_metrics.txt' 14 | 15 | # Retrieve checkpoint 16 | python3 -m MaxText.train MaxText/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ 17 | steps=110 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='restored_metrics.txt' 18 | 19 | 20 | python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss 21 | 22 | # Clean up ramdisk 23 | rm -rf /local/* 24 | -------------------------------------------------------------------------------- /end_to_end/tpu/gemma/Run_Gemma.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Gemma 18 | [Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from research and technology that we used to create the Gemini models. 19 | 20 | Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to the license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). 21 | 22 | After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads it to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma). 23 | 24 | ## MaxText supports pretraining and finetuning with high performance 25 | 26 | Model Flop utilization for training on v5e and v5p TPUs.
27 | 28 | | Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) | 29 | | -------- | -------------- | -------------- | -------------- | -------------- | 30 | | Gemma-2b | 58% | 55% | 64% | 68% | 31 | | Gemma-7b | 58% | 60% | 70% | 70% | 32 | -------------------------------------------------------------------------------- /end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This file tests the quantization of the Llama3.1-405b checkpoint, and assumes an unscanned checkpoint already exists. 4 | 5 | set -ex 6 | 7 | # We install torch CPU because the checkpoint conversion script does not need a TPU/GPU 8 | python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu 9 | 10 | # This is defined in 2_test_llama3.1_405b.sh 11 | export MODEL_VARIATION='llama3.1-405b' 12 | export UNSCANNED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_bf16/unscanned/0/items 13 | 14 | # Non-Googlers please remember to point `SAVE_QUANT_PARAMS_PATH` to the GCS bucket where you want to save your quantized checkpoint 15 | export SAVE_QUANT_PARAMS_PATH=gs://maxtext-llama/llama3.1_405b_int8 16 | 17 | export QUANTIZE_TYPE="int8" 18 | 19 | JAX_PLATFORMS=cpu python3 -m MaxText.load_and_quantize_checkpoint \ 20 | MaxText/configs/base.yml \ 21 | tokenizer_path=assets/tokenizer_llama3.tiktoken \ 22 | tokenizer_type=tiktoken \ 23 | load_parameters_path=${UNSCANNED_CHECKPOINT} \ 24 | max_prefill_predict_length=1024 \ 25 | max_target_length=2048 \ 26 | model_name=${MODEL_VARIATION} \ 27 | ici_fsdp_parallelism=1 \ 28 | ici_autoregressive_parallelism=1 \ 29 | ici_tensor_parallelism=-1 \ 30 | scan_layers=false \ 31 | weight_dtype=bfloat16 \ 32 | per_device_batch_size=1 \ 33 | attention=dot_product \ 34 | quantization=${QUANTIZE_TYPE} \ 35 | save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} \ 36 | async_checkpointing=false 37 | -------------------------------------------------------------------------------- /end_to_end/tpu/llama_finetuning_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is designed for internal use within Google. External users can adapt it by: 4 | # - Updating GCS paths (gs://) to your accessible locations. 5 | # - Using the checkpoint generated from train.py or an available open-source one (https://llama.meta.com/llama-downloads/).
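# For example, an external adaptation might set (bucket names below are hypothetical placeholders):
#   base_ckpt_path=gs://<your-bucket>/llama2-7b/decode-ckpt-maxtext/0/items
#   BASE_OUTPUT_DIRECTORY=gs://<your-bucket>/logs
#   DATASET_PATH=gs://<your-bucket>/dataset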
6 | 7 | set -e 8 | idx=$(date +%Y-%m-%d-%H-%M) 9 | 10 | base_ckpt_path=gs://maxtext-llama/test/2024-01-15-06-49/decode-ckpt-maxtext/0/items 11 | BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs 12 | DATASET_PATH=gs://maxtext-dataset 13 | 14 | export LOSS_THRESHOLD=2.5 15 | 16 | python3 -m MaxText.train MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' 17 | 18 | # Assert training loss is smaller than input LOSS_THRESHOLD 19 | python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD -------------------------------------------------------------------------------- /end_to_end/tpu/mixtral/Run_Mixtral.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Mixtral 18 | 19 | [Mixtral](https://mistral.ai/news/mixtral-of-experts/) is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture. 20 | 21 | 22 | To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../MaxText/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You can find a Mixtral 8x7B example in the [end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts. 23 | 24 | 25 | Additionally, Mixtral integrates with [MegaBlocks](https://arxiv.org/abs/2211.15841), an efficient dropless MoE strategy, which can be activated by setting both the sparse_matmul and megablox flags to True (the default). 26 | 27 | 28 | ## MaxText supports pretraining and finetuning with high performance 29 | 30 | Model Flop utilization for training on v5p TPUs.
31 | 32 | | Model size | Accelerator type | TFLOP/chip/sec | Model flops utilization (MFU) | 33 | | ------------ | -------------- | -------------- | -------------- | 34 | | Mixtral 8X7B | v5p-128 | 251.94 | 54.89% | 35 | 36 | 37 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_checkpoint_resharding.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H) 5 | OUTPUT_PATH=${2} 6 | DATASET_PATH=${3} 7 | 8 | # Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism 9 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=101\ 10 | metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ 11 | dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False 12 | 13 | # Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism 14 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=102\ 15 | metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ 16 | dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False 17 | 18 | python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss 19 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_determinism.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H) 5 | OUTPUT_PATH=${2} 6 | DATASET_PATH=${3} 7 | DATASET_TYPE=${4} 8 | 9 | if [ "$DATASET_TYPE" == "grain" ] 10 | then 11 | EVAL_METRICS=grain_checkpoint_save_restore 12 | echo "Using grain dataset type" 13 | echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" 14 | bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ 15 | DATASET_PATH=/tmp/gcsfuse/ 16 | CMD_DATA=" dataset_type=grain grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*" 17 | fi 18 | 19 | #Train 20 | CMD1="python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME}_1 steps=5 metrics_file=run_1_metrics.txt\ 21 | enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" 22 | CMD1+=$CMD_DATA 23 | 24 | 25 | CMD2="python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME}_2 steps=5 metrics_file=run_2_metrics.txt\ 26 | enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" 27 | CMD2+=$CMD_DATA 28 | 29 | $CMD1 30 | $CMD2 31 | python3 end_to_end/tpu/eval_assert.py determinism metrics.txt learning/loss 32 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_dpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | RUN_NAME=dpo_$(date +%Y-%m-%d-%H-%M-%S) 6 | 7 | # get latest converted Gemma2 2B checkpoint from internal GCS bucket 8 | export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sort -r | head -1) 9 | LOGS="gs://maxtext-external/logs" 10 | 11 | # tfds pipeline 12 | python3 -m MaxText.train MaxText/configs/dpo.yml 
tokenizer_path=assets/tokenizer.gemma \ 13 | run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \ 14 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ 15 | per_device_batch_size=0.5 allow_split_physical_axes=True \ 16 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 17 | 18 | # grain pipeline 19 | mkdir -p /tmp/anthropic_rlhf || true 20 | gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf 21 | python3 -m MaxText.train MaxText/configs/dpo.yml tokenizer_path=assets/tokenizer.gemma \ 22 | run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ 23 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ 24 | dataset_type=grain grain_worker_count=16 \ 25 | grain_train_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-train.array_record*' \ 26 | grain_eval_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-test.array_record*' \ 27 | per_device_batch_size=0.5 allow_split_physical_axes=True \ 28 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 29 | 30 | # hf pipeline 31 | python3 -m MaxText.train MaxText/configs/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ 32 | run_name="$RUN_NAME-hf" model_name=gemma2-2b base_output_directory=${LOGS} \ 33 | load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ 34 | dataset_type=hf hf_access_token=$HF_TOKEN hf_path='Anthropic/hh-rlhf' \ 35 | per_device_batch_size=0.5 allow_split_physical_axes=True \ 36 | ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 37 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_gpt3.sh: -------------------------------------------------------------------------------- 1 | set -euox pipefail 2 | 3 | TIMESTAMP=$(date +%Y%m%d-%H%M) 4 | export PAXML_CKPT_PATH=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state 5 | export OUTPUT_PATH=gs://maxtext-gpt3/tests 6 | export RUN_NAME=test_${TIMESTAMP} 7 | 8 | # convert gpt3-52k model 9 | python3 -m MaxText.convert_gpt3_ckpt_from_paxml --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH} 10 | 11 | # Run gpt3-52k with the converted ckpt 12 | python3 -m MaxText.train MaxText/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ 13 | steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\ 14 | remat_policy=full max_target_length=2048 base_output_directory=${OUTPUT_PATH}\ 15 | dataset_type=synthetic 16 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_tflops.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | USER=${1} 5 | TFLOP_THRESHOLD=${2} 6 | OUTPUT_PATH=${3} 7 | DATASET_PATH=${4} 8 | 9 | 10 | if [ -z "${5}" ] 11 | then 12 | RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S) 13 | else 14 | RUN_NAME=${5}_$(date +%Y-%m-%d-%H) 15 | fi 16 | 17 | #Train 18 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\ 19 | steps=150 reuse_example_batch=1 remat_policy='full' profiler=xplane enable_checkpointing=False metrics_file='metrics.txt'\ 20 | base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH log_period=150 21 | 22 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec 23 | 
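Each of the TFLOP/s tests above and below finishes with `eval_assert.py metrics_average`, which averages a per-step metric from the metrics file and fails the run if the average misses the threshold. The real `end_to_end/tpu/eval_assert.py` is not included in this dump, so the following is only a minimal sketch of the idea; the one-JSON-object-per-line metrics layout and the exact CLI semantics are assumptions.

import json
import sys


def metrics_average(metrics_file, threshold, key):
  """Average `key` over all logged steps; fail if the average misses `threshold`."""
  values = []
  with open(metrics_file, "rt", encoding="utf8") as f:
    for line in f:
      step = json.loads(line)  # assumed layout: one JSON object per training step
      if key in step:
        values.append(float(step[key]))
  assert values, f"no '{key}' entries found in {metrics_file}"
  average = sum(values) / len(values)
  print(f"average {key} over {len(values)} steps: {average}")
  assert average >= float(threshold), f"{average} is below threshold {threshold}"


if __name__ == "__main__":
  # Example (hypothetical): python3 eval_assert_sketch.py metrics_average metrics.txt 100 perf/per_device_tflops_per_sec
  _, mode, metrics_file, threshold, key = sys.argv
  assert mode == "metrics_average", "only the metrics_average check is sketched here"
  metrics_average(metrics_file, threshold, key)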
-------------------------------------------------------------------------------- /end_to_end/tpu/test_tflops_16b_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Running test_tflops_16b_params.sh" 3 | 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce") 9 | # TFLOP_THRESHOLD (Optional, default is 0) 10 | # 11 | # Example to invoke this script: 12 | # bash end_to_end/tpu/test_tflops_16b_params.sh RUN_NAME="<run_name>" OUTPUT_PATH="gs://<output_path>" DATASET_PATH="gs://<dataset_path>" PLATFORM="gke" TFLOP_THRESHOLD=0 13 | 14 | # Stop execution if any command exits with error 15 | set -ex 16 | 17 | export TFLOP_THRESHOLD=0 18 | export PLATFORM="gce" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # Set up network optimizations 27 | bash preflight.sh PLATFORM=$PLATFORM 28 | 29 | # Train 30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\ 32 | steps=150 per_device_batch_size=6 enable_checkpointing=false remat_policy=full\ 33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ 34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=16 35 | 36 | # Assert TFLOP/s 37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec 38 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_tflops_32b_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Running test_tflops_32b_params.sh" 3 | 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce") 9 | # TFLOP_THRESHOLD (Optional, default is 0) 10 | # 11 | # Example to invoke this script: 12 | # bash end_to_end/tpu/test_tflops_32b_params.sh RUN_NAME="<run_name>" OUTPUT_PATH="gs://<output_path>" DATASET_PATH="gs://<dataset_path>" PLATFORM="gke" TFLOP_THRESHOLD=0 13 | 14 | # Stop execution if any command exits with error 15 | set -ex 16 | 17 | export TFLOP_THRESHOLD=0 18 | export PLATFORM="gce" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # Set up network optimizations 27 | bash preflight.sh PLATFORM=$PLATFORM 28 | 29 | # Train 30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\ 32 | steps=150 per_device_batch_size=4 enable_checkpointing=false remat_policy=full\ 33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ 34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=32 35 | 36 | # Assert TFLOP/s 37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec 38 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_tflops_64b_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Running test_tflops_64b_params.sh" 3 | 4 | # Command Flags: 5 | # OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) 6 | # DATASET_PATH (Required, unless dataset_path is already set in base.yml) 7 | # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) 8 | # PLATFORM (Optional, can be "gke" or "gce", default is "gce") 9 | # TFLOP_THRESHOLD (Optional, default is 0) 10 | # 11 | # Example to invoke this script: 12 | # bash end_to_end/tpu/test_tflops_64b_params.sh RUN_NAME="<run_name>" OUTPUT_PATH="gs://<output_path>" DATASET_PATH="gs://<dataset_path>" PLATFORM="gke" TFLOP_THRESHOLD=0 13 | 14 | # Stop execution if any command exits with error 15 | set -ex 16 | 17 | export TFLOP_THRESHOLD=0 18 | export PLATFORM="gce" 19 | 20 | # Set environment variables 21 | for ARGUMENT in "$@"; do 22 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 23 | export "$KEY"="$VALUE" 24 | done 25 | 26 | # Set up network optimizations 27 | bash preflight.sh PLATFORM=$PLATFORM 28 | 29 | # Train 30 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" 31 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME\ 32 | steps=150 per_device_batch_size=2 enable_checkpointing=false remat_policy=full\ 33 | max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ 34 | dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=64 35 | 36 | # Assert TFLOP/s 37 | python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec 38 | -------------------------------------------------------------------------------- /end_to_end/tpu/test_vocab_creation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | RUN_NAME=${1}_$(date +%Y-%m-%d-%H) 5 | OUTPUT_PATH=${2} 6 | DATASET_PATH=${3} 7 | VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME 8 | 9 | 10 | #Train 11 | python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ 12 | base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH 13 | 14 | python3 end_to_end/tpu/eval_assert.py vocab_creation $VOCAB_PATH 15 | -------------------------------------------------------------------------------- 
/getting_started/GCP_Workload_Observability.md: -------------------------------------------------------------------------------- 1 | # Enable GCP Workload Observability 2 | This guide provides an overview of how to enable GCP workload observability for your MaxText workload. 3 | 4 | ## Overview 5 | Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes. 6 | Once enabled, metrics will be automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring. 7 | If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed. 8 | 9 | The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added. 10 | Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metrics. 11 | 12 | This guide lays out how to enable the feature for your MaxText workload. 13 | 14 | ## Enabling GCP Workload Observability 15 | Users can control which metrics they want to report via config: 16 | 17 | ### Heartbeat metric 18 | - This metric will be a boolean flag. 19 | - To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True` 20 | - To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value. 21 | 22 | ### Performance metric 23 | - This metric will be a double, capturing the training step time in seconds. 24 | - To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True` 25 | 26 | For an example, please refer to [base.yml](../MaxText/configs/base.yml). -------------------------------------------------------------------------------- /getting_started/Run_Llama2.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | ## About Llama2 18 | 19 | MaxText supports [Llama2](https://llama.meta.com/llama2) pretraining, finetuning and decoding for its 7B and 70B flavors. To get started on decoding and finetuning of Llama2, you will first need to download the weights along with the tokenizer from [Meta](https://llama.meta.com/llama-downloads). 20 | 21 | The file [test_llama2_7b.sh](https://github.com/google/maxtext/blob/main/end_to_end/tpu/llama2/7b/test_llama2_7b.sh) provides details on how to convert the PyTorch weights into Orbax checkpoint format, and thereafter use it for running decoding and finetuning. [test_llama2_7b.sh](https://github.com/google/maxtext/blob/main/end_to_end/tpu/llama2/7b/test_llama2_7b.sh) also shows how to run pretraining and how to run decoding on the finetuned model checkpoint. 22 | 23 | ### MaxText supports pretraining and finetuning with high performance. 24 | 25 | Model Flop utilization for training on v4, v5p, and v5e TPUs with MaxText. 26 | 27 | 28 | | Model | v4-128 (bf16) | v5p-128 (bf16) | v5e-256 (bf16) | 29 | | ---------- | -------------- | -------------- | -------------- | 30 | | Llama2-70b | 57% | 65% | 57% | 31 | -------------------------------------------------------------------------------- /maxtext_custom_wheels.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASEIMAGE=maxtext_base_image 2 | FROM $BASEIMAGE 3 | 4 | # Requires wheels to be in /deps.
This means any custom wheels should be placed 5 | # in the maxtext directory. 6 | RUN python3 -m pip install --force-reinstall /deps/*.whl 7 | -------------------------------------------------------------------------------- /maxtext_db_dependencies.Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:experimental 2 | # Copy benchmark-db 3 | FROM gcr.io/tpu-prod-env-one-vm/benchmark-db:2025-02-14 4 | 5 | # Install system dependencies 6 | RUN apt-get update && apt-get install -y curl gnupg 7 | 8 | # Add the Google Cloud SDK package repository 9 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 10 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 11 | 12 | # Install the Google Cloud SDK 13 | RUN apt-get update && apt-get install -y google-cloud-sdk 14 | 15 | # Set the default Python version to 3.10 16 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 17 | 18 | # Set environment variables for Google Cloud SDK and Python 3.10 19 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}" 20 | 21 | # Set environment variables via build arguments 22 | ARG MODE 23 | ENV ENV_MODE=$MODE 24 | 25 | ARG JAX_VERSION 26 | ENV ENV_JAX_VERSION=$JAX_VERSION 27 | 28 | ARG LIBTPU_GCS_PATH 29 | ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH 30 | 31 | ARG DEVICE 32 | ENV ENV_DEVICE=$DEVICE 33 | 34 | RUN mkdir -p /deps 35 | 36 | # Set the working directory in the container 37 | WORKDIR /deps 38 | 39 | # Copy setup files and dependency files separately for better caching 40 | COPY setup.sh ./ 41 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./ 42 | 43 | # Install dependencies - these steps are cached unless the copied files change 44 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" 45 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} 46 | 47 | # Now copy the remaining code (source files that may change frequently) 48 | COPY . . 
49 | -------------------------------------------------------------------------------- /maxtext_dependencies.Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:experimental 2 | # Use Python 3.10 as the base image 3 | FROM python:3.10-slim-bullseye 4 | 5 | # Install system dependencies 6 | RUN apt-get update && apt-get install -y curl gnupg 7 | 8 | # Add the Google Cloud SDK package repository 9 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 10 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 11 | 12 | # Install the Google Cloud SDK 13 | RUN apt-get update && apt-get install -y google-cloud-sdk 14 | 15 | # Set the default Python version to 3.10 16 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 17 | 18 | # Set environment variables for Google Cloud SDK and Python 3.10 19 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}" 20 | 21 | # Set environment variables via build arguments 22 | ARG MODE 23 | ENV ENV_MODE=$MODE 24 | 25 | ARG JAX_VERSION 26 | ENV ENV_JAX_VERSION=$JAX_VERSION 27 | 28 | ARG LIBTPU_GCS_PATH 29 | ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH 30 | 31 | ARG DEVICE 32 | ENV ENV_DEVICE=$DEVICE 33 | 34 | RUN mkdir -p /deps 35 | 36 | # Set the working directory in the container 37 | WORKDIR /deps 38 | 39 | # Copy setup files and dependency files separately for better caching 40 | COPY setup.sh ./ 41 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./ 42 | 43 | # Install dependencies - these steps are cached unless the copied files change 44 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" 45 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} 46 | 47 | # Now copy the remaining code (source files that may change frequently) 48 | COPY . . 49 | -------------------------------------------------------------------------------- /maxtext_gpu_dependencies.Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:experimental 2 | ARG BASEIMAGE=ghcr.io/nvidia/jax:base 3 | FROM $BASEIMAGE 4 | 5 | # Stopgap measure to circumvent gpg key setup issue.
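# Note: [trusted=yes] below skips apt signature verification for the NVIDIA devtools repository,
# so this workaround should be removed once the gpg key issue is resolved.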
6 | RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list 7 | 8 | # Install dependencies for adjusting network rto 9 | RUN apt-get update && apt-get install -y iproute2 ethtool lsof 10 | 11 | # Install DNS util dependencies 12 | RUN apt-get install -y dnsutils 13 | 14 | # Add the Google Cloud SDK package repository 15 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 16 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 17 | 18 | # Install the Google Cloud SDK 19 | RUN apt-get update && apt-get install -y google-cloud-sdk 20 | 21 | # Set environment variables for Google Cloud SDK 22 | ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}" 23 | 24 | # Upgrade libcusparse to work with Jax 25 | RUN apt-get update && apt-get install -y libcusparse-12-6 26 | 27 | ARG MODE 28 | ENV ENV_MODE=$MODE 29 | 30 | ARG JAX_VERSION 31 | ENV ENV_JAX_VERSION=$JAX_VERSION 32 | 33 | ARG DEVICE 34 | ENV ENV_DEVICE=$DEVICE 35 | 36 | RUN mkdir -p /deps 37 | 38 | # Set the working directory in the container 39 | WORKDIR /deps 40 | 41 | # Copy setup files and dependency files separately for better caching 42 | COPY setup.sh ./ 43 | COPY constraints_gpu.txt requirements.txt requirements_with_jax_ai_image.txt ./ 44 | 45 | # Install dependencies - these steps are cached unless the copied files change 46 | RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" 47 | RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} 48 | 49 | # Now copy the remaining code (source files that may change frequently) 50 | COPY . . 51 | -------------------------------------------------------------------------------- /maxtext_jax_ai_image.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG JAX_AI_IMAGE_BASEIMAGE 2 | 3 | # JAX AI Base Image 4 | FROM $JAX_AI_IMAGE_BASEIMAGE 5 | ARG JAX_AI_IMAGE_BASEIMAGE 6 | 7 | ARG COMMIT_HASH 8 | 9 | ENV COMMIT_HASH=$COMMIT_HASH 10 | 11 | RUN mkdir -p /deps 12 | 13 | # Set the working directory in the container 14 | WORKDIR /deps 15 | 16 | # Copy setup files and dependency files separately for better caching 17 | COPY setup.sh ./ 18 | COPY requirements.txt requirements_with_jax_ai_image.txt ./ 19 | 20 | 21 | # For JAX AI tpu training images 0.4.37 AND 0.4.35 22 | # Orbax checkpoint installs the latest version of JAX, 23 | # but the libtpu version in the base image is older. 24 | # This version mismatch can cause compatibility issues 25 | # and break MaxText.
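# (Reinstalling jax[tpu] also pulls in a libtpu wheel matched to that JAX release, realigning the two versions.)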
--------------------------------------------------------------------------------
/maxtext_jax_ai_image.Dockerfile:
--------------------------------------------------------------------------------
ARG JAX_AI_IMAGE_BASEIMAGE

# JAX AI Base Image
FROM $JAX_AI_IMAGE_BASEIMAGE
ARG JAX_AI_IMAGE_BASEIMAGE

ARG COMMIT_HASH

ENV COMMIT_HASH=$COMMIT_HASH

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY setup.sh ./
COPY requirements.txt requirements_with_jax_ai_image.txt ./

# For the JAX AI TPU training images 0.4.37 and 0.4.35:
# orbax-checkpoint installs the latest version of JAX,
# but the libtpu version in the base image is older.
# This version mismatch can cause compatibility issues
# and break MaxText.
# Upgrade the libtpu version if using either of the old stable images.

ARG DEVICE
ENV DEVICE=$DEVICE

RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1" ] || [ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1" ]); then \
    python3 -m pip install --no-cache-dir --upgrade jax[tpu]; fi

# Install MaxText requirements with the JAX AI Image
RUN apt-get update && apt-get install --yes dnsutils
# TODO(bvandermoon, parambole): Remove this when it's added to the JAX AI Image
RUN pip install google-cloud-monitoring
RUN python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt

# Now copy the remaining code (source files that may change frequently)
COPY . .
RUN ls .

# Run the script available in the JAX AI base image to generate the manifest file
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
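The libtpu/JAX version mismatch called out above is easiest to catch at runtime. A minimal sanity check, assuming it is run inside the built image on a TPU host (this snippet is illustrative and not part of the image):

# A jax/libtpu mismatch usually surfaces when the TPU runtime is first initialized.
import jax

print("jax version:", jax.__version__)
# jax.devices() initializes the backend; on a healthy TPU image it lists
# TPU devices instead of raising an initialization error.
print("devices:", jax.devices())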
--------------------------------------------------------------------------------
/maxtext_libtpu_path.Dockerfile:
--------------------------------------------------------------------------------
ARG BASEIMAGE=maxtext_base_image
FROM $BASEIMAGE

# Set the TPU_LIBRARY_PATH
ENV TPU_LIBRARY_PATH='/root/custom_libtpu/libtpu.so'

WORKDIR /deps
--------------------------------------------------------------------------------
/maxtext_runner.Dockerfile:
--------------------------------------------------------------------------------
# syntax=docker.io/docker/dockerfile:1.7-labs

ARG BASEIMAGE=maxtext_base_image
FROM $BASEIMAGE

# Set the working directory in the container
WORKDIR /deps

# Copy assets separately
COPY assets assets/
COPY MaxText/test_assets/ MaxText/test_assets/

# Copy all files except assets from the local workspace into the docker container
COPY --exclude=assets --exclude=MaxText/test_assets . .
--------------------------------------------------------------------------------
/pedagogical_examples/__init__.py:
--------------------------------------------------------------------------------
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
--------------------------------------------------------------------------------
/pedagogical_examples/non_spmd.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3

"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
This program demonstrates embarrassingly parallelizable non-SPMD computations in JAX, in this case by having each
process_index run its own computation.
The same approach can be extended to computations that are not embarrassingly parallelizable.
The simplest way to do that would be to run the embarrassingly parallelizable computations on arbitrary submeshes,
then use `host_local_array_to_global_array` to reshard into a new global array.
An important limitation of this approach is that we cannot overlap communication and computation between the different
kernel calls.
"""

import numpy as np

import jax
from jax.sharding import PartitionSpec
from jax.sharding import Mesh


# Notice this is jax.local_devices(), not jax.devices(). Hence each process (on TPU VMs, each VM) runs its own
# program on its own mesh.
mesh = Mesh(np.array(jax.local_devices()), ["data"])
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
idx = jax.process_index()


# The example step depends on idx, which differs across processes.
def example_step():
  return idx * jax.numpy.ones((idx + 1))


jit_func = jax.jit(
    example_step,
    out_shardings=sharding,
)

# pylint: disable=not-callable
print(f"{idx=} -> {jit_func()=}")
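To make the resharding step from the docstring concrete, here is a minimal sketch of `host_local_array_to_global_array`, assuming a multi-process run; the shapes and mesh axis name are illustrative, not taken from the file above:

# Each process builds a host-local result; the pieces are then stitched into
# one global jax.Array sharded over all devices.
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec

# A global mesh over all devices, unlike the per-process mesh above.
global_mesh = Mesh(np.array(jax.devices()), ["data"])

# Assume each process produced a host-local array of the same shape.
local_result = np.full((8, 128), jax.process_index(), dtype=np.float32)

# Concatenate the per-host pieces along the sharded "data" dimension.
global_result = multihost_utils.host_local_array_to_global_array(
    local_result, global_mesh, PartitionSpec("data")
)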
--------------------------------------------------------------------------------
/preflight.sh:
--------------------------------------------------------------------------------
#!/bin/bash
echo "Running preflight.sh"
# Command Flags:
#
# Example to invoke this script:
# bash preflight.sh

# Warning:
# For any dependencies, please add them to `setup.sh` or `maxtext_dependencies.Dockerfile`.
# You should not install any dependencies in this file.

# Stop execution if any command exits with an error
set -e

# Set environment variables from KEY=VALUE command-line arguments
for ARGUMENT in "$@"; do
  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
  export "$KEY"="$VALUE"
done

# Check if sudo is available
if command -v sudo >/dev/null 2>&1; then
  # sudo is available, use it
  echo "running rto_setup.sh with sudo"

  # apply network settings.
  sudo bash rto_setup.sh
else
  # sudo is not available, run the script without sudo
  echo "running rto_setup.sh without sudo"

  # apply network settings.
  bash rto_setup.sh
fi
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "MaxText"
version = "0.1.0"
description = "MaxText is a high-performance, highly scalable, open-source LLM written in pure Python/JAX and targeting Google Cloud TPUs and GPUs for training and inference"
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
dynamic = ["dependencies"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}

[tool.setuptools.packages.find]
where = ["MaxText"]
include = ["MaxText*"]
exclude = ["MaxText.tests.*"]
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
# pytest.ini
[pytest]
testpaths =
    tests
python_files = *_test.py *_tests.py
addopts =
    -rf --import-mode=importlib --strict-markers
    --ignore=MaxText/tests/profiler_test.py
    --ignore=MaxText/tests/train_smoke_test.py
    --ignore=MaxText/tests/train_int8_smoke_test.py
    --ignore=MaxText/tests/train_gpu_smoke_test.py
    --ignore=MaxText/tests/train_using_ragged_dot_smoke_test.py
markers =
    tpu_only: marks tests to be run on TPUs only
    gpu_only: marks tests to be run on GPUs only
    cpu_only: marks tests to be run on CPUs only
    integration_test: tests exercising larger portions of the system,
        including interactions with other systems like GCS,
        e.g., end_to_end tests
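As a usage sketch for the custom markers registered above (the test below is illustrative, not taken from the repository):

import pytest


@pytest.mark.tpu_only
def test_illustrative_tpu_only_behavior():
  # Selected with `pytest -m tpu_only`, excluded with `pytest -m "not tpu_only"`.
  assert True

Because --strict-markers is set in addopts, a marker that is not declared in this file fails collection instead of being silently ignored.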
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
jax>=0.4.30
jaxlib>=0.4.30
orbax-checkpoint>=0.5.12
absl-py
array-record
aqtp
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
gcsfs
google-cloud-aiplatform==1.61.0
google-cloud-storage
google-cloud-monitoring
google-api-core
google-api-python-client
grain[parquet]>=0.2.6
huggingface_hub
flax>=0.10.6
jaxtyping
ml-collections
ml-goodput-measurement==0.0.10
numpy
optax
protobuf==3.20.3
pylint
pytest
pyink
pre-commit
pytype
pillow>=11.1.0
sentencepiece==0.2.0
tensorflow-text>=2.13.0
tensorflow>=2.13.0
tensorflow-datasets
tensorboardx>=2.6.2.2
tensorboard-plugin-profile
tiktoken
transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
jsonlines
pathwaysutils==0.1.1
omegaconf
--------------------------------------------------------------------------------
/requirements_with_jax_ai_image.txt:
--------------------------------------------------------------------------------
# Requirements for building the MaxText Docker image.
# These requirements are additional to the dependencies present in the JAX AI base image.
datasets
grain[parquet]>=0.2.6
orbax-checkpoint>=0.10.3
pylint
pytest
pyink
pre-commit
protobuf==3.20.3
pytype
pillow>=11.1.0
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow-datasets
tiktoken
transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
jsonlines
pathwaysutils==0.1.1
google-api-python-client
omegaconf
jaxtyping
--------------------------------------------------------------------------------
/rto_setup.sh:
--------------------------------------------------------------------------------
echo "Running rto_setup.sh"

# Stop execution if any command exits with an error
set -e

echo "Adjusting network settings and applying non-cache copy"

# Disable slow start after idle
sysctl net.ipv4.tcp_slow_start_after_idle=0

# Disable metrics cache
sysctl net.ipv4.tcp_no_metrics_save=1

# Address the rto_min issue with two default routing entries: screen/7RGgkiXkGXSeYF2
route=$(ip route show | sed -n 1p)
second_route=$(ip route show | sed -n 2p)
if [[ "${second_route}" =~ ^default.* ]]; then
  modified_route=${route//" lock"/}
  ip route delete ${modified_route}
fi
route=$(ip route show | sed -n 1p)
echo "route rto before change: $route"
if [[ "${route}" =~ .*lock.*5ms.* ]]; then
  echo "${route}"
else
  # shellcheck disable=SC2086
  ip route change $route rto_min 5ms
fi
route=$(ip route show | sed -n 1p)
echo "route rto after change: $route"

# Disable Cubic Hystart Ack-Train
echo 2 > /sys/module/tcp_cubic/parameters/hystart_detect

# Improve handling of SYN bursts
echo 4096 > /proc/sys/net/core/somaxconn
echo 4096 > /proc/sys/net/ipv4/tcp_max_syn_backlog

# Disable MTU Discovery
echo 0 > /proc/sys/net/ipv4/tcp_mtu_probing

# Increase TCP Zerocopy control memory
sysctl -w net.core.optmem_max=131072

# Print the output of `ip route show`
echo -e "\nPrinting output of 'ip route show':"
ip route show

first_line_res=$(ip route show | head -n 1)
dev_name=$(echo "$first_line_res" | awk -F'[[:space:]]' '{ print $5 }')
echo "dev_name=${dev_name}"
ethtool -K "${dev_name}" tx-nocache-copy on

echo "rto_setup.sh finished"
--------------------------------------------------------------------------------
/setup_with_retries.sh:
--------------------------------------------------------------------------------
#!/bin/bash

max_attempts=5
current_attempt=1

# Loop until setup succeeds or reaches the maximum number of attempts
while [ $current_attempt -le $max_attempts ]; do
  echo "Attempt $current_attempt to run setup.sh:"
  bash setup.sh "$@"

  # Check the exit status of setup.sh
  if [ $? -eq 0 ]; then
    echo "Successfully ran setup on attempt $current_attempt!"
    exit 0
  else
    echo "Failed to run setup on attempt $current_attempt."
    ((current_attempt++))
    sleep 5  # Short delay before the next attempt
  fi
done

echo "All attempts to run setup failed. Exiting."
exit 1
--------------------------------------------------------------------------------
/unit_test_and_lint.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python3 -m pylint $(git ls-files '*.py')

python3 -m pytest --pyargs MaxText.tests
--------------------------------------------------------------------------------