├── .axlearn └── axlearn.default.config ├── .circleci └── config.yml ├── .github ├── scripts │ └── monitor_memory.sh └── workflows │ ├── build.yml │ └── pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── ACKNOWLEDGEMENTS.md ├── CHANGELOG.md ├── CODEOWNERS ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── axlearn ├── __init__.py ├── audio │ ├── __init__.py │ ├── decoder_asr.py │ ├── decoder_asr_test.py │ ├── encoder_asr.py │ ├── encoder_asr_test.py │ ├── evaler_asr.py │ ├── evaler_asr_test.py │ ├── frontend.py │ ├── frontend_benchmark.py │ ├── frontend_test.py │ ├── frontend_utils.py │ ├── frontend_utils_test.py │ ├── input_asr.py │ ├── input_asr_test.py │ ├── model_asr.py │ ├── model_asr_test.py │ ├── spectrum_augmenter.py │ ├── spectrum_augmenter_test.py │ ├── subsamplers.py │ ├── subsamplers_test.py │ └── test_utils.py ├── cli │ ├── __init__.py │ ├── gcp.py │ ├── testdata │ │ └── dummy.py │ ├── utils.py │ └── utils_test.py ├── cloud │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── bastion.py │ │ ├── bastion_test.py │ │ ├── bundler.py │ │ ├── bundler_test.py │ │ ├── cleaner.py │ │ ├── cleaner_test.py │ │ ├── config.py │ │ ├── config_test.py │ │ ├── docker.py │ │ ├── event_queue.py │ │ ├── event_queue_test.py │ │ ├── git_summary.py │ │ ├── git_summary_test.py │ │ ├── job.py │ │ ├── job_test.py │ │ ├── quota.py │ │ ├── quota_test.py │ │ ├── scheduler.py │ │ ├── scheduler_test.py │ │ ├── testdata │ │ │ └── counter.py │ │ ├── types.py │ │ ├── uploader.py │ │ ├── uploader_test.py │ │ ├── utils.py │ │ ├── utils_test.py │ │ ├── validator.py │ │ ├── writer.py │ │ └── writer_test.py │ └── gcp │ │ ├── __init__.py │ │ ├── bundler.py │ │ ├── bundler_test.py │ │ ├── cloud_build.py │ │ ├── cloud_build_test.py │ │ ├── config.py │ │ ├── config_test.py │ │ ├── event_queue.py │ │ ├── event_queue_test.py │ │ ├── examples │ │ ├── dataflow_inference_custom.py │ │ └── dataflow_inference_hf.py │ │ ├── job.py │ │ ├── job_flink.py │ │ ├── job_flink_test.py │ │ ├── job_pathways.py │ │ ├── job_pathways_test.py │ │ ├── job_test.py │ │ ├── jobs │ │ ├── __init__.py │ │ ├── bastion_vm.py │ │ ├── bastion_vm_test.py │ │ ├── cpu_runner.py │ │ ├── cpu_runner_test.py │ │ ├── dataflow.py │ │ ├── dataflow_test.py │ │ ├── gke_runner.py │ │ ├── launch.py │ │ ├── launch_test.py │ │ ├── launch_utils.py │ │ ├── launch_utils_test.py │ │ ├── logs.py │ │ └── logs_test.py │ │ ├── jobset_utils.py │ │ ├── jobset_utils_test.py │ │ ├── measurement.py │ │ ├── measurement_test.py │ │ ├── monitoring │ │ ├── __init__.py │ │ ├── testdata │ │ │ ├── sample_metrics.txt │ │ │ └── sample_metrics_idle.txt │ │ ├── tpu_client.py │ │ ├── tpu_client_test.py │ │ ├── tpu_device_monitor.py │ │ └── tpu_device_monitor_test.py │ │ ├── nccl │ │ ├── a3_mega │ │ │ ├── guest_config.txtpb │ │ │ └── tuner_config.txtpb │ │ ├── a3_ultra │ │ │ ├── guest_config.txtpb │ │ │ └── tuner_config.txtpb │ │ └── a4_high │ │ │ ├── guest_config.txtpb │ │ │ └── tuner_config.txtpb │ │ ├── node_pool.py │ │ ├── node_pool_provisioner.py │ │ ├── node_pool_provisioner_test.py │ │ ├── node_pool_test.py │ │ ├── pathways_utils.py │ │ ├── pathways_utils_test.py │ │ ├── runners │ │ ├── __init__.py │ │ ├── base.py │ │ ├── gke.py │ │ ├── gke_test.py │ │ ├── utils.py │ │ └── utils_test.py │ │ ├── scopes.py │ │ ├── scripts │ │ ├── project_setup.sh │ │ ├── start_tpu.sh │ │ └── start_vm.sh │ │ ├── storage.py │ │ ├── storage_test.py │ │ ├── system_characteristics.py │ │ ├── test_utils.py │ │ ├── tpu.py │ │ ├── tpu_health_check.py │ │ ├── 
tpu_health_check_main.py │ │ ├── tpu_health_check_test.py │ │ ├── tpu_test.py │ │ ├── utils.py │ │ ├── utils_test.py │ │ ├── vertexai_tensorboard.py │ │ ├── vertexai_tensorboard_test.py │ │ ├── vm.py │ │ └── vm_test.py ├── common │ ├── __init__.py │ ├── adapter_flax.py │ ├── adapter_flax_test.py │ ├── adapter_torch.py │ ├── adapter_torch_test.py │ ├── aot_compilation.py │ ├── aot_compilation_test.py │ ├── array_serialization.py │ ├── array_serialization_test.py │ ├── attention.py │ ├── attention_bias.py │ ├── attention_bias_test.py │ ├── attention_test.py │ ├── base_encoder_decoder.py │ ├── base_layer.py │ ├── base_layer_test.py │ ├── base_model.py │ ├── bert.py │ ├── bert_test.py │ ├── causal_lm.py │ ├── causal_lm_test.py │ ├── checkpointer.py │ ├── checkpointer_orbax.py │ ├── checkpointer_orbax_emergency.py │ ├── checkpointer_orbax_emergency_test.py │ ├── checkpointer_orbax_test.py │ ├── checkpointer_test.py │ ├── compiler_options.py │ ├── compiler_options_test.py │ ├── config.py │ ├── config_test.py │ ├── conformer.py │ ├── conformer_test.py │ ├── convolution.py │ ├── convolution_test.py │ ├── deberta.py │ ├── deberta_test.py │ ├── debug_utils.py │ ├── debug_utils_test.py │ ├── decoder.py │ ├── decoder_test.py │ ├── decoding.py │ ├── decoding_test.py │ ├── distilbert.py │ ├── distilbert_test.py │ ├── distillation.py │ ├── distillation_test.py │ ├── dit.py │ ├── dit_test.py │ ├── ein_ops.py │ ├── ein_ops_test.py │ ├── embedding.py │ ├── embedding_test.py │ ├── encoder.py │ ├── encoder_decoder.py │ ├── encoder_decoder_test.py │ ├── encoder_test.py │ ├── env_test.py │ ├── eval_classification.py │ ├── eval_classification_test.py │ ├── eval_retrieval.py │ ├── eval_retrieval_test.py │ ├── evaler.py │ ├── evaler_test.py │ ├── factorized_rms.py │ ├── factorized_rms_test.py │ ├── file_system.py │ ├── file_system_test.py │ ├── flash_attention │ │ ├── __init__.py │ │ ├── common.py │ │ ├── decoding_test.py │ │ ├── gpu_attention.py │ │ ├── gpu_attention_benchmark.py │ │ ├── gpu_attention_test.py │ │ ├── gpu_decoding.py │ │ ├── gpu_paged_attention.py │ │ ├── layer.py │ │ ├── layer_test.py │ │ ├── neuron_attention.py │ │ ├── neuron_attention_test.py │ │ ├── remat.py │ │ ├── remat_test.py │ │ ├── test_utils.py │ │ ├── tpu_attention.py │ │ ├── tpu_attention_benchmark.py │ │ ├── tpu_attention_test.py │ │ ├── tpu_decoding.py │ │ ├── tpu_paged_attention.py │ │ ├── tpu_splash_attention.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── gda_test.py │ ├── gradient_accumulation.py │ ├── gradient_accumulation_test.py │ ├── host_array_test.py │ ├── inference.py │ ├── inference_output.py │ ├── inference_pipeline.py │ ├── inference_test.py │ ├── input_base.py │ ├── input_base_test.py │ ├── input_composite.py │ ├── input_composite_test.py │ ├── input_dispatch.py │ ├── input_dispatch_test.py │ ├── input_fake.py │ ├── input_glue.py │ ├── input_glue_test.py │ ├── input_grain.py │ ├── input_grain_lm.py │ ├── input_grain_lm_test.py │ ├── input_grain_test.py │ ├── input_grain_text.py │ ├── input_grain_text_test.py │ ├── input_lm.py │ ├── input_lm_test.py │ ├── input_mlm.py │ ├── input_mlm_test.py │ ├── input_ranking.py │ ├── input_ranking_test.py │ ├── input_reading_comprehension.py │ ├── input_reading_comprehension_test.py │ ├── input_t5.py │ ├── input_t5_test.py │ ├── input_text.py │ ├── input_text_test.py │ ├── input_tf_data.py │ ├── input_tf_data_test.py │ ├── launch.py │ ├── launch_test.py │ ├── launch_trainer.py │ ├── launch_trainer_main.py │ ├── launch_trainer_test.py │ ├── layers.py │ ├── layers_test.py │ ├── 
learner.py │ ├── learner_base.py │ ├── learner_test.py │ ├── liveness_monitor.py │ ├── liveness_monitor_test.py │ ├── logit_modifiers.py │ ├── logit_modifiers_test.py │ ├── lora.py │ ├── lora_test.py │ ├── loss.py │ ├── loss_metrics.py │ ├── loss_test.py │ ├── measurement.py │ ├── measurement_test.py │ ├── metrics.py │ ├── metrics_classification.py │ ├── metrics_classification_test.py │ ├── metrics_correlation.py │ ├── metrics_correlation_test.py │ ├── metrics_glue.py │ ├── metrics_glue_test.py │ ├── metrics_retrieval.py │ ├── metrics_retrieval_test.py │ ├── metrics_test.py │ ├── metrics_text_dual_encoder.py │ ├── metrics_text_dual_encoder_test.py │ ├── mixture_of_experts.py │ ├── mixture_of_experts_test.py │ ├── module.py │ ├── module_test.py │ ├── monitoring │ │ ├── __init__.py │ │ ├── device_monitor.py │ │ ├── device_monitor_test.py │ │ ├── gpu_client.py │ │ ├── gpu_client_test.py │ │ ├── gpu_device_monitor.py │ │ └── gpu_device_monitor_test.py │ ├── multi_stream_model.py │ ├── multi_stream_model_test.py │ ├── multiway_transformer.py │ ├── multiway_transformer_test.py │ ├── neural_retrieval.py │ ├── neural_retrieval_test.py │ ├── normalize.py │ ├── normalize_test.py │ ├── ops │ │ ├── __init__.py │ │ ├── _optimization_barrier.py │ │ └── _optimization_barrier_test.py │ ├── optimizer_base.py │ ├── optimizers.py │ ├── optimizers_test.py │ ├── param_converter.py │ ├── param_converter_test.py │ ├── param_init.py │ ├── param_init_test.py │ ├── pipeline.py │ ├── pipeline_test.py │ ├── poolings.py │ ├── poolings_test.py │ ├── quantized_dot_general │ │ ├── __init__.py │ │ ├── activation_clipping.py │ │ ├── fp8_ops.py │ │ ├── layers.py │ │ ├── layers_test.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── quantizer.py │ ├── quantizer_test.py │ ├── repeat.py │ ├── repeat_test.py │ ├── rnn.py │ ├── rnn_test.py │ ├── schedule.py │ ├── schedule_test.py │ ├── serialization.py │ ├── serialization_test.py │ ├── splade.py │ ├── splade_test.py │ ├── ssm.py │ ├── ssm_kernels │ │ ├── __init__.py │ │ ├── mamba_kernels.py │ │ ├── mamba_kernels_test.py │ │ ├── ssd_kernels.py │ │ └── ssd_kernels_test.py │ ├── ssm_test.py │ ├── state_builder.py │ ├── state_builder_test.py │ ├── status_server.py │ ├── status_server_test.py │ ├── struct.py │ ├── struct_test.py │ ├── summary.py │ ├── summary_test.py │ ├── summary_writer.py │ ├── summary_writer_test.py │ ├── t5.py │ ├── t5_test.py │ ├── test_utils.py │ ├── test_utils_test.py │ ├── text_dual_encoder.py │ ├── text_dual_encoder_test.py │ ├── text_encoder.py │ ├── torch_utils.py │ ├── traceback_util.py │ ├── traceback_util_test.py │ ├── trainer.py │ ├── trainer_config_modifier.py │ ├── trainer_config_modifier_test.py │ ├── trainer_test.py │ ├── transducer.py │ ├── transducer_test.py │ ├── update_transformation.py │ ├── update_transformation_test.py │ ├── utils.py │ ├── utils_spmd.py │ ├── utils_test.py │ ├── utils_text_dual_encoder.py │ ├── utils_text_dual_encoder_test.py │ ├── utils_tf.py │ ├── utils_tf_test.py │ ├── vision_transformer.py │ ├── vision_transformer_test.py │ ├── vocabulary_bpe.py │ └── vocabulary_bpe_test.py ├── data │ └── tokenizers │ │ └── sentencepiece │ │ ├── bpe_128k.json │ │ ├── bpe_128k_c4.model │ │ ├── bpe_128k_c4.vocab │ │ ├── bpe_32k.json │ │ ├── bpe_32k_c4.model │ │ ├── bpe_32k_c4.vocab │ │ ├── librispeech_bpe_1024.model │ │ └── librispeech_unigram_1024.model ├── experiments │ ├── __init__.py │ ├── aot_test.py │ ├── audio │ │ ├── __init__.py │ │ └── conformer │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── common_test.py │ │ │ ├── 
librispeech_trainer.py │ │ │ └── librispeech_trainer_test.py │ ├── calculate_goodput.py │ ├── conftest.py │ ├── golden_ckpt_test.py │ ├── golden_config_test.py │ ├── run_aot_compilation.py │ ├── test_utils.py │ ├── testdata │ │ ├── axlearn.common.conformer_test │ │ │ └── test_against_fairseq.npy │ │ ├── axlearn.common.encoder_decoder_test │ │ │ ├── test_against_t5x_False.npy │ │ │ └── test_against_t5x_True.npy │ │ ├── axlearn.common.param_converter_test │ │ │ ├── test_parameters_from_t5x_attention.npy │ │ │ ├── test_parameters_from_t5x_decoder.npy │ │ │ ├── test_parameters_from_t5x_dense.npy │ │ │ ├── test_parameters_from_t5x_embedding.npy │ │ │ ├── test_parameters_from_t5x_encoder.npy │ │ │ ├── test_parameters_from_t5x_encoder_decoder.npy │ │ │ ├── test_parameters_from_t5x_ff.npy │ │ │ ├── test_parameters_from_t5x_layer_norm.npy │ │ │ ├── test_parameters_from_t5x_rel_pos_emb_False.npy │ │ │ ├── test_parameters_from_t5x_rel_pos_emb_True.npy │ │ │ └── test_parameters_from_t5x_transformer_layer.npy │ │ ├── axlearn.common.quantizer_test │ │ │ └── test_forward_against_fairseq.npy │ │ ├── axlearn.common.t5_test │ │ │ ├── test_buckets_against_t5x_False_100.npy │ │ │ ├── test_buckets_against_t5x_False_20.npy │ │ │ ├── test_buckets_against_t5x_False_256.npy │ │ │ ├── test_buckets_against_t5x_True_100.npy │ │ │ ├── test_buckets_against_t5x_True_20.npy │ │ │ └── test_buckets_against_t5x_True_256.npy │ │ ├── axlearn.experiments.audio.conformer.librispeech_trainer │ │ │ ├── conformer-l-rnnt.txt │ │ │ ├── conformer-l-rnnt_init.txt │ │ │ ├── conformer-l-rnnt_regularizer.txt │ │ │ ├── conformer-test-ctc.txt │ │ │ ├── conformer-test-ctc_init.txt │ │ │ └── conformer-test-ctc_regularizer.txt │ │ ├── axlearn.experiments.text.gpt.c4_trainer │ │ │ ├── fuji-1B-v3-flash-fp8-single-host.txt │ │ │ ├── fuji-1B-v3-flash-fp8-single-host_init.txt │ │ │ ├── fuji-1B-v3-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-flash-fp8.txt │ │ │ ├── fuji-1B-v3-flash-fp8_init.txt │ │ │ ├── fuji-1B-v3-flash-fp8_regularizer.txt │ │ │ ├── fuji-1B-v3-flash-single-host.txt │ │ │ ├── fuji-1B-v3-flash-single-host_init.txt │ │ │ ├── fuji-1B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-flash.txt │ │ │ ├── fuji-1B-v3-flash_init.txt │ │ │ ├── fuji-1B-v3-flash_regularizer.txt │ │ │ ├── fuji-1B-v3-fp8-single-host.txt │ │ │ ├── fuji-1B-v3-fp8-single-host_init.txt │ │ │ ├── fuji-1B-v3-fp8-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-fp8.txt │ │ │ ├── fuji-1B-v3-fp8_init.txt │ │ │ ├── fuji-1B-v3-fp8_regularizer.txt │ │ │ ├── fuji-1B-v3-single-host.txt │ │ │ ├── fuji-1B-v3-single-host_init.txt │ │ │ ├── fuji-1B-v3-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8-single-host.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-fp8_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-single-host.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-fp8-single-host.txt │ │ │ ├── fuji-1B-v3-tiktoken-fp8-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-fp8-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-fp8.txt │ │ │ ├── fuji-1B-v3-tiktoken-fp8_init.txt │ │ 
│ ├── fuji-1B-v3-tiktoken-fp8_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-single-host.txt │ │ │ ├── fuji-1B-v3-tiktoken-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken.txt │ │ │ ├── fuji-1B-v3-tiktoken_init.txt │ │ │ ├── fuji-1B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-1B-v3.txt │ │ │ ├── fuji-1B-v3_init.txt │ │ │ ├── fuji-1B-v3_regularizer.txt │ │ │ ├── fuji-3B-v3-flash-fp8-single-host.txt │ │ │ ├── fuji-3B-v3-flash-fp8-single-host_init.txt │ │ │ ├── fuji-3B-v3-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-flash-fp8.txt │ │ │ ├── fuji-3B-v3-flash-fp8_init.txt │ │ │ ├── fuji-3B-v3-flash-fp8_regularizer.txt │ │ │ ├── fuji-3B-v3-flash-single-host.txt │ │ │ ├── fuji-3B-v3-flash-single-host_init.txt │ │ │ ├── fuji-3B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-flash.txt │ │ │ ├── fuji-3B-v3-flash_init.txt │ │ │ ├── fuji-3B-v3-flash_regularizer.txt │ │ │ ├── fuji-3B-v3-fp8-single-host.txt │ │ │ ├── fuji-3B-v3-fp8-single-host_init.txt │ │ │ ├── fuji-3B-v3-fp8-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-fp8.txt │ │ │ ├── fuji-3B-v3-fp8_init.txt │ │ │ ├── fuji-3B-v3-fp8_regularizer.txt │ │ │ ├── fuji-3B-v3-single-host.txt │ │ │ ├── fuji-3B-v3-single-host_init.txt │ │ │ ├── fuji-3B-v3-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8-single-host.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-fp8_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-single-host.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8-single-host.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-fp8_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-single-host.txt │ │ │ ├── fuji-3B-v3-tiktoken-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken.txt │ │ │ ├── fuji-3B-v3-tiktoken_init.txt │ │ │ ├── fuji-3B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-3B-v3.txt │ │ │ ├── fuji-3B-v3_init.txt │ │ │ ├── fuji-3B-v3_regularizer.txt │ │ │ ├── fuji-70B-v1-flash-fp8.txt │ │ │ ├── fuji-70B-v1-flash-fp8_init.txt │ │ │ ├── fuji-70B-v1-flash-fp8_regularizer.txt │ │ │ ├── fuji-70B-v1-flash.txt │ │ │ ├── fuji-70B-v1-flash_init.txt │ │ │ ├── fuji-70B-v1-flash_regularizer.txt │ │ │ ├── fuji-70B-v1-fp8.txt │ │ │ ├── fuji-70B-v1-fp8_init.txt │ │ │ ├── fuji-70B-v1-fp8_regularizer.txt │ │ │ ├── fuji-70B-v1.txt │ │ │ ├── fuji-70B-v1_init.txt │ │ │ ├── fuji-70B-v1_regularizer.txt │ │ │ ├── fuji-70B-v2-flash-fp8.txt │ │ │ ├── fuji-70B-v2-flash-fp8_init.txt │ │ │ ├── fuji-70B-v2-flash-fp8_regularizer.txt │ │ │ ├── fuji-70B-v2-flash.txt │ │ │ ├── fuji-70B-v2-flash_init.txt │ │ │ ├── fuji-70B-v2-flash_regularizer.txt │ │ │ ├── fuji-70B-v2-fp8.txt │ │ │ ├── fuji-70B-v2-fp8_init.txt │ │ │ ├── fuji-70B-v2-fp8_regularizer.txt │ │ │ ├── fuji-70B-v2.txt │ │ │ ├── fuji-70B-v2_init.txt │ │ │ ├── fuji-70B-v2_regularizer.txt │ │ │ ├── fuji-70B-v3-flash-fp8.txt │ │ │ ├── 
fuji-70B-v3-flash-fp8_init.txt │ │ │ ├── fuji-70B-v3-flash-fp8_regularizer.txt │ │ │ ├── fuji-70B-v3-flash.txt │ │ │ ├── fuji-70B-v3-flash_init.txt │ │ │ ├── fuji-70B-v3-flash_regularizer.txt │ │ │ ├── fuji-70B-v3-fp8.txt │ │ │ ├── fuji-70B-v3-fp8_init.txt │ │ │ ├── fuji-70B-v3-fp8_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash-fp8.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash-fp8_init.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash-fp8_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken-fp8.txt │ │ │ ├── fuji-70B-v3-tiktoken-fp8_init.txt │ │ │ ├── fuji-70B-v3-tiktoken-fp8_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken.txt │ │ │ ├── fuji-70B-v3-tiktoken_init.txt │ │ │ ├── fuji-70B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-70B-v3.txt │ │ │ ├── fuji-70B-v3_init.txt │ │ │ ├── fuji-70B-v3_regularizer.txt │ │ │ ├── fuji-7B-v1-flash-fp8-single-host.txt │ │ │ ├── fuji-7B-v1-flash-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v1-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v1-flash-fp8.txt │ │ │ ├── fuji-7B-v1-flash-fp8_init.txt │ │ │ ├── fuji-7B-v1-flash-fp8_regularizer.txt │ │ │ ├── fuji-7B-v1-flash-single-host.txt │ │ │ ├── fuji-7B-v1-flash-single-host_init.txt │ │ │ ├── fuji-7B-v1-flash-single-host_regularizer.txt │ │ │ ├── fuji-7B-v1-flash.txt │ │ │ ├── fuji-7B-v1-flash_init.txt │ │ │ ├── fuji-7B-v1-flash_regularizer.txt │ │ │ ├── fuji-7B-v1-fp8-single-host.txt │ │ │ ├── fuji-7B-v1-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v1-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v1-fp8.txt │ │ │ ├── fuji-7B-v1-fp8_init.txt │ │ │ ├── fuji-7B-v1-fp8_regularizer.txt │ │ │ ├── fuji-7B-v1-single-host.txt │ │ │ ├── fuji-7B-v1-single-host_init.txt │ │ │ ├── fuji-7B-v1-single-host_regularizer.txt │ │ │ ├── fuji-7B-v1.txt │ │ │ ├── fuji-7B-v1_init.txt │ │ │ ├── fuji-7B-v1_regularizer.txt │ │ │ ├── fuji-7B-v2-flash-fp8-single-host.txt │ │ │ ├── fuji-7B-v2-flash-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v2-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2-flash-fp8.txt │ │ │ ├── fuji-7B-v2-flash-fp8_init.txt │ │ │ ├── fuji-7B-v2-flash-fp8_regularizer.txt │ │ │ ├── fuji-7B-v2-flash-single-host.txt │ │ │ ├── fuji-7B-v2-flash-single-host_init.txt │ │ │ ├── fuji-7B-v2-flash-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2-flash.txt │ │ │ ├── fuji-7B-v2-flash_init.txt │ │ │ ├── fuji-7B-v2-flash_regularizer.txt │ │ │ ├── fuji-7B-v2-fp8-single-host.txt │ │ │ ├── fuji-7B-v2-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v2-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2-fp8.txt │ │ │ ├── fuji-7B-v2-fp8_init.txt │ │ │ ├── fuji-7B-v2-fp8_regularizer.txt │ │ │ ├── fuji-7B-v2-single-host.txt │ │ │ ├── fuji-7B-v2-single-host_init.txt │ │ │ ├── fuji-7B-v2-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2.txt │ │ │ ├── fuji-7B-v2_init.txt │ │ │ ├── fuji-7B-v2_regularizer.txt │ │ │ ├── fuji-7B-v3-flash-fp8-single-host.txt │ │ │ ├── fuji-7B-v3-flash-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v3-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v3-flash-fp8.txt │ │ │ ├── fuji-7B-v3-flash-fp8_init.txt │ │ │ ├── fuji-7B-v3-flash-fp8_regularizer.txt │ │ │ ├── fuji-7B-v3-flash-single-host.txt │ │ │ ├── fuji-7B-v3-flash-single-host_init.txt │ │ │ ├── fuji-7B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-7B-v3-flash.txt │ │ │ ├── fuji-7B-v3-flash_init.txt │ │ │ ├── fuji-7B-v3-flash_regularizer.txt │ │ │ ├── fuji-7B-v3-fp8-single-host.txt │ │ │ ├── 
fuji-7B-v3-fp8-single-host_init.txt │ │ │ ├── fuji-7B-v3-fp8-single-host_regularizer.txt │ │ │ ├── fuji-7B-v3-fp8.txt │ │ │ ├── fuji-7B-v3-fp8_init.txt │ │ │ ├── fuji-7B-v3-fp8_regularizer.txt │ │ │ ├── fuji-7B-v3-single-host.txt │ │ │ ├── fuji-7B-v3-single-host_init.txt │ │ │ ├── fuji-7B-v3-single-host_regularizer.txt │ │ │ ├── fuji-7B-v3.txt │ │ │ ├── fuji-7B-v3_init.txt │ │ │ ├── fuji-7B-v3_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8-single-host.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8-single-host_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8-single-host_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-fp8_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-single-host.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-single-host_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8-single-host.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8-single-host_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8-single-host_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-fp8_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-single-host.txt │ │ │ ├── fuji-8B-v3-tiktoken-single-host_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken.txt │ │ │ ├── fuji-8B-v3-tiktoken_init.txt │ │ │ ├── fuji-8B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-golden-run-test-v1.txt │ │ │ ├── fuji-golden-run-test-v1_init.txt │ │ │ ├── fuji-golden-run-test-v1_regularizer.txt │ │ │ ├── fuji-golden-run-test-v2.txt │ │ │ ├── fuji-golden-run-test-v2_init.txt │ │ │ ├── fuji-golden-run-test-v2_regularizer.txt │ │ │ ├── fuji-golden-run-test-v3-tiktoken.txt │ │ │ ├── fuji-golden-run-test-v3-tiktoken_init.txt │ │ │ ├── fuji-golden-run-test-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-golden-run-test-v3.txt │ │ │ ├── fuji-golden-run-test-v3_init.txt │ │ │ ├── fuji-golden-run-test-v3_regularizer.txt │ │ │ ├── fuji-test-v1-flash.txt │ │ │ ├── fuji-test-v1-flash_init.txt │ │ │ ├── fuji-test-v1-flash_regularizer.txt │ │ │ ├── fuji-test-v1.txt │ │ │ ├── fuji-test-v1_init.txt │ │ │ ├── fuji-test-v1_regularizer.txt │ │ │ ├── fuji-test-v2-flash.txt │ │ │ ├── fuji-test-v2-flash_init.txt │ │ │ ├── fuji-test-v2-flash_regularizer.txt │ │ │ ├── fuji-test-v2.txt │ │ │ ├── fuji-test-v2_init.txt │ │ │ ├── fuji-test-v2_regularizer.txt │ │ │ ├── fuji-test-v3-flash.txt │ │ │ ├── fuji-test-v3-flash_init.txt │ │ │ ├── fuji-test-v3-flash_regularizer.txt │ │ │ ├── fuji-test-v3-tiktoken-flash.txt │ │ │ ├── fuji-test-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-test-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-test-v3-tiktoken.txt │ │ │ ├── fuji-test-v3-tiktoken_init.txt │ │ │ ├── fuji-test-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-test-v3.txt │ │ │ ├── fuji-test-v3_init.txt │ │ │ ├── fuji-test-v3_regularizer.txt │ │ │ ├── gspmd-16B-2x16x8-stream.txt │ │ │ ├── gspmd-16B-2x16x8-stream_init.txt │ │ │ ├── gspmd-16B-2x16x8-stream_regularizer.txt │ │ │ ├── gspmd-16B-2x16x8.txt │ │ │ ├── gspmd-16B-2x16x8_init.txt │ │ │ └── gspmd-16B-2x16x8_regularizer.txt │ │ ├── axlearn.experiments.text.gpt.deterministic_trainer │ │ │ ├── gala-1B-flash-pajama-2t-32k.txt │ │ │ ├── gala-1B-hybridnorm-alibi-flash-pajama-2t-32k.txt │ │ │ ├── gala-302M-flash-pajama-2t-32k.txt │ │ │ ├── 
gala-7B-flash-pajama-2t-32k.txt │ │ │ ├── gala-7B-hybridnorm-alibi-flash-pajama-2t-32k.txt │ │ │ ├── gala-85M-flash-pajama-2t-32k.txt │ │ │ ├── gala-test-flash-pajama-2t-32k.txt │ │ │ ├── honeycrisp-3B-flash-pajama-15t-49k.txt │ │ │ ├── honeycrisp-85M-flash-pajama-15t-49k.txt │ │ │ └── honeycrisp-test-flash-pajama-15t-49k.txt │ │ ├── axlearn.experiments.text.gpt.pajama_sigmoid_trainer │ │ │ ├── gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp.txt │ │ │ ├── gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp_init.txt │ │ │ ├── gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-1B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k.txt │ │ │ ├── gala-sigmoid-1B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_init.txt │ │ │ ├── gala-sigmoid-1B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ │ │ ├── gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp.txt │ │ │ ├── gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp_init.txt │ │ │ ├── gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k.txt │ │ │ ├── gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_init.txt │ │ │ ├── gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ │ │ ├── gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp.txt │ │ │ ├── gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp_init.txt │ │ │ ├── gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k.txt │ │ │ ├── gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_init.txt │ │ │ └── gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ │ ├── axlearn.experiments.text.gpt.pajama_trainer │ │ │ ├── gala-1B-flash-sp-rp.txt │ │ │ ├── gala-1B-flash-sp-rp_init.txt │ │ │ ├── gala-1B-flash-sp-rp_regularizer.txt │ │ │ ├── gala-1B-hybridnorm-alibi-flash-sp-rp.txt │ │ │ ├── gala-1B-hybridnorm-alibi-flash-sp-rp_init.txt │ │ │ ├── gala-1B-hybridnorm-alibi-flash-sp-rp_regularizer.txt │ │ │ ├── gala-1B-sp-rp.txt │ │ │ ├── gala-1B-sp-rp_init.txt │ │ │ ├── gala-1B-sp-rp_regularizer.txt │ │ │ ├── gala-302M-flash-sp-rp.txt │ │ │ ├── gala-302M-flash-sp-rp_init.txt │ │ │ ├── gala-302M-flash-sp-rp_regularizer.txt │ │ │ ├── gala-302M-sp-rp.txt │ │ │ ├── gala-302M-sp-rp_init.txt │ │ │ ├── gala-302M-sp-rp_regularizer.txt │ │ │ ├── gala-7B-flash-sp-rp.txt │ │ │ ├── gala-7B-flash-sp-rp_init.txt │ │ │ ├── gala-7B-flash-sp-rp_regularizer.txt │ │ │ ├── gala-7B-hybridnorm-alibi-flash-sp-rp.txt │ │ │ ├── gala-7B-hybridnorm-alibi-flash-sp-rp_init.txt │ │ │ ├── gala-7B-hybridnorm-alibi-flash-sp-rp_regularizer.txt │ │ │ ├── gala-7B-hybridnorm-alibi-sp-rp.txt │ │ │ ├── gala-7B-hybridnorm-alibi-sp-rp_init.txt │ │ │ ├── gala-7B-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-7B-sp-rp.txt │ │ │ ├── gala-7B-sp-rp_init.txt │ │ │ ├── gala-7B-sp-rp_regularizer.txt │ │ │ ├── gala-85M-flash-sp-rp.txt │ │ │ ├── gala-85M-flash-sp-rp_init.txt │ │ │ ├── gala-85M-flash-sp-rp_regularizer.txt │ │ │ ├── gala-85M-sp-rp.txt │ │ │ ├── gala-85M-sp-rp_init.txt │ │ │ ├── gala-85M-sp-rp_regularizer.txt │ │ │ ├── gala-test-flash-sp-rp.txt │ │ │ ├── gala-test-flash-sp-rp_init.txt │ │ │ ├── gala-test-flash-sp-rp_regularizer.txt │ │ │ ├── gala-test-sp-rp.txt │ │ │ ├── gala-test-sp-rp_init.txt │ │ │ ├── gala-test-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-3B-flash-sp-rp.txt │ │ │ ├── honeycrisp-3B-flash-sp-rp_init.txt │ │ │ ├── honeycrisp-3B-flash-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-3B-sp-rp.txt │ │ │ 
├── honeycrisp-3B-sp-rp_init.txt │ │ │ ├── honeycrisp-3B-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-85M-flash-sp-rp.txt │ │ │ ├── honeycrisp-85M-flash-sp-rp_init.txt │ │ │ ├── honeycrisp-85M-flash-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-85M-sp-rp.txt │ │ │ ├── honeycrisp-85M-sp-rp_init.txt │ │ │ ├── honeycrisp-85M-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-test-flash-sp-rp.txt │ │ │ ├── honeycrisp-test-flash-sp-rp_init.txt │ │ │ ├── honeycrisp-test-flash-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-test-sp-rp.txt │ │ │ ├── honeycrisp-test-sp-rp_init.txt │ │ │ └── honeycrisp-test-sp-rp_regularizer.txt │ │ ├── axlearn.experiments.vision.resnet.imagenet_trainer │ │ │ ├── ResNet-101.txt │ │ │ ├── ResNet-101_init.txt │ │ │ ├── ResNet-101_regularizer.txt │ │ │ ├── ResNet-152.txt │ │ │ ├── ResNet-152_init.txt │ │ │ ├── ResNet-152_regularizer.txt │ │ │ ├── ResNet-18.txt │ │ │ ├── ResNet-18_init.txt │ │ │ ├── ResNet-18_regularizer.txt │ │ │ ├── ResNet-34.txt │ │ │ ├── ResNet-34_init.txt │ │ │ ├── ResNet-34_regularizer.txt │ │ │ ├── ResNet-50-ema.txt │ │ │ ├── ResNet-50-ema_init.txt │ │ │ ├── ResNet-50-ema_regularizer.txt │ │ │ ├── ResNet-50.txt │ │ │ ├── ResNet-50_init.txt │ │ │ ├── ResNet-50_regularizer.txt │ │ │ ├── ResNet-Test.txt │ │ │ ├── ResNet-Test_init.txt │ │ │ ├── ResNet-Test_regularizer.txt │ │ │ ├── ResNet-Testb.txt │ │ │ ├── ResNet-Testb_init.txt │ │ │ └── ResNet-Testb_regularizer.txt │ │ └── axlearn_common_measurement_test │ │ │ ├── __init__.py │ │ │ └── dummy_recorder.py │ ├── text │ │ ├── __init__.py │ │ ├── common.py │ │ ├── gpt │ │ │ ├── __init__.py │ │ │ ├── c4_trainer.py │ │ │ ├── common.py │ │ │ ├── common_test.py │ │ │ ├── deterministic_trainer.py │ │ │ ├── fuji.py │ │ │ ├── gala.py │ │ │ ├── gala_sigmoid.py │ │ │ ├── gala_sigmoid_test.py │ │ │ ├── gspmd.py │ │ │ ├── honeycrisp.py │ │ │ ├── pajama_sigmoid_trainer.py │ │ │ ├── pajama_trainer.py │ │ │ ├── param_converter_test.py │ │ │ ├── vocabulary_fuji_v3.py │ │ │ └── vocabulary_fuji_v3_test.py │ │ └── train_spm.py │ ├── trainer_config_utils.py │ ├── trainer_config_utils_test.py │ └── vision │ │ ├── __init__.py │ │ ├── imagenet │ │ ├── __init__.py │ │ ├── common.py │ │ └── common_test.py │ │ └── resnet │ │ ├── __init__.py │ │ ├── common.py │ │ ├── common_test.py │ │ ├── imagenet_trainer.py │ │ └── imagenet_trainer_test.py ├── huggingface │ ├── __init__.py │ ├── hf_extractive_qa.py │ ├── hf_extractive_qa_test.py │ ├── hf_module.py │ ├── hf_pretrained_loaders.py │ ├── hf_pretrained_loaders_test.py │ ├── hf_sequence_classification.py │ ├── hf_sequence_classification_test.py │ └── hf_text_encoder.py ├── open_api │ ├── __init__.py │ ├── anthropic.py │ ├── anthropic_test.py │ ├── common.py │ ├── common_test.py │ ├── eval_set │ │ ├── __init__.py │ │ ├── mmau.py │ │ └── mmau_test.py │ ├── evaluator.py │ ├── evaluator_test.py │ ├── gemini.py │ ├── gemini_test.py │ ├── generator.py │ ├── generator_test.py │ ├── metrics │ │ ├── __init__.py │ │ ├── code_contests.py │ │ ├── code_contests_test.py │ │ ├── code_execute.py │ │ ├── code_execute_test.py │ │ ├── code_kaggle.py │ │ ├── code_kaggle_test.py │ │ ├── math.py │ │ ├── math_test.py │ │ ├── tool_use_execution.py │ │ ├── tool_use_execution_test.py │ │ ├── tool_use_execution_utils.py │ │ ├── tool_use_execution_utils_test.py │ │ ├── tool_use_plan.py │ │ └── tool_use_plan_test.py │ ├── mock_utils.py │ ├── openai.py │ ├── openai_test.py │ ├── registry.py │ └── registry_test.py └── vision │ ├── __init__.py │ ├── anchor.py │ ├── anchor_test.py │ ├── attention.py │ ├── attention_test.py │ ├── 
augment.py │ ├── augment_test.py │ ├── beit_image_tokenizer.py │ ├── beit_image_tokenizer_test.py │ ├── box_coder.py │ ├── box_coder_test.py │ ├── clip.py │ ├── clip_test.py │ ├── coca.py │ ├── coca_test.py │ ├── coco_evaluator.py │ ├── coco_utils.py │ ├── coco_utils_test.py │ ├── cyclip.py │ ├── cyclip_test.py │ ├── detection_generator.py │ ├── detection_generator_test.py │ ├── detection_heads.py │ ├── detection_heads_test.py │ ├── efficientdet.py │ ├── efficientdet_test.py │ ├── eval_detection.py │ ├── eval_detection_test.py │ ├── feature_tokenizer.py │ ├── feature_tokenizer_test.py │ ├── fpn.py │ ├── fpn_test.py │ ├── image_classification.py │ ├── image_classification_test.py │ ├── imagenet_adversarial_text │ ├── README.md │ ├── __init__.py │ ├── add_attack_tfrecord.py │ ├── imagenet-simple.json │ ├── openai_clip_pred_1tfrecord_target2esti.pickle │ ├── util_im_process.py │ ├── util_imagenet.py │ ├── util_tfdata.py │ └── zeroshot_eval_with_opensource_clip.ipynb │ ├── input_detection.py │ ├── input_detection_test.py │ ├── input_image.py │ ├── input_image_test.py │ ├── mask_generator.py │ ├── mask_generator_test.py │ ├── masked_image_model.py │ ├── masked_image_model_test.py │ ├── matchers.py │ ├── matchers_test.py │ ├── metrics_vqa.py │ ├── metrics_vqa_test.py │ ├── mobilenets.py │ ├── mobilenets_blocks.py │ ├── mobilenets_blocks_test.py │ ├── mobilenets_test.py │ ├── nms.py │ ├── nms_test.py │ ├── param_converter.py │ ├── param_converter_test.py │ ├── rcnn.py │ ├── rcnn_losses.py │ ├── rcnn_losses_test.py │ ├── rcnn_sampler.py │ ├── rcnn_sampler_test.py │ ├── rcnn_test.py │ ├── resnet.py │ ├── resnet_test.py │ ├── retinanet.py │ ├── retinanet_test.py │ ├── roi_aligner.py │ ├── roi_aligner_test.py │ ├── roi_generator.py │ ├── roi_generator_test.py │ ├── rpn_sampler.py │ ├── rpn_sampler_test.py │ ├── samplers.py │ ├── samplers_test.py │ ├── similarity_ops.py │ ├── similarity_ops_test.py │ ├── spatial_transform_ops.py │ ├── spatial_transform_ops_test.py │ ├── utils_detection.py │ ├── utils_detection_test.py │ ├── utils_visualization.py │ ├── virtex.py │ ├── virtex_test.py │ ├── vitdet_transformer.py │ ├── vitdet_transformer_test.py │ ├── window_attention.py │ └── window_attention_test.py ├── conftest.py ├── docs ├── 01-start.md ├── 02-concepts.md ├── 03-cli.md ├── 04-infrastructure.md ├── 05-Goodput-Monitoring.md ├── ml_api_style.md └── research │ └── mmau │ ├── README.md │ └── figures │ ├── MMAU-herofig.png │ └── results_radar_bar_combined.png ├── pyproject.toml └── run_tests.sh /.github/scripts/monitor_memory.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Simple memory monitor that prints to console periodically. 3 | # Designed for GitHub Actions monitoring. 4 | 5 | INFO="\033[90m" 6 | RESET="\033[0m" 7 | 8 | echo -e "${INFO}====== Starting memory monitor at $(date) ======${RESET}" 9 | 10 | while true; do 11 | # Get current timestamp 12 | timestamp=$(date '+%Y-%m-%d %H:%M:%S') 13 | 14 | echo "" 15 | echo -e "${INFO}=== Memory Check: $timestamp ===${RESET}" 16 | 17 | # Print simplified memory stats. 18 | free -h | grep "Mem:" | awk "{printf \"${INFO}Memory: %s used, %s free, %s total\n${RESET}\", \$3, \$4, \$2}" 19 | 20 | # Print memory usage percentage. 
21 | mem_total=$(free | grep Mem | awk '{print $2}') 22 | mem_used=$(free | grep Mem | awk '{print $3}') 23 | mem_percent=$(awk "BEGIN {printf \"%.1f\", $mem_used/$mem_total*100}") 24 | echo -e "${INFO}Memory usage: $mem_percent%${RESET}" 25 | 26 | # Sleep for 30 seconds 27 | sleep 30 28 | done 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # To run locally: 2 | # % pre-commit run -a 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.3.0 6 | hooks: 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | # These files are generated. 10 | exclude: "axlearn/experiments/testdata/.*" 11 | - id: trailing-whitespace 12 | - repo: local 13 | hooks: 14 | - id: black 15 | name: black 16 | entry: black 17 | language: system 18 | types: [python] 19 | - id: isort 20 | name: isort 21 | entry: isort 22 | language: system 23 | types: [python] 24 | - id: pylint 25 | name: pylint 26 | entry: pylint 27 | args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"'] 28 | language: system 29 | types: [python] 30 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @ruomingp @markblee @apple/axlearn-admins 2 | axlearn/cloud/ @apple/axlearn-cloud 3 | -------------------------------------------------------------------------------- /axlearn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn.""" 4 | -------------------------------------------------------------------------------- /axlearn/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/audio/__init__.py -------------------------------------------------------------------------------- /axlearn/audio/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Speech testing utils.""" 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from axlearn.common.utils import Tensor 9 | 10 | 11 | def fake_audio( 12 | *, 13 | batch_size: int, 14 | seq_len: int, 15 | prng_key: Tensor, 16 | scale: float = 32768.0, 17 | dtype: jnp.dtype = jnp.float32, 18 | ): 19 | """Generates fake audio data with a fixed seed.""" 20 | input_key, length_key = jax.random.split(prng_key) 21 | inputs = jax.random.uniform( 22 | input_key, 23 | shape=[batch_size, seq_len], 24 | minval=-scale, 25 | maxval=scale, 26 | dtype=jnp.float32, 27 | ).astype(dtype) 28 | lengths = jax.random.randint(length_key, shape=[batch_size, 1], minval=0, maxval=seq_len) 29 | paddings = (jnp.arange(seq_len)[None, :] >= lengths).astype(jnp.bool) 30 | return inputs, paddings 31 | -------------------------------------------------------------------------------- /axlearn/cli/testdata/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """A dummy absl program.""" 4 | 5 | from absl import app, flags 6 | 7 | flags.DEFINE_string("required", None, "A required flag.", required=True) 8 | flags.DEFINE_string("optional", None, "An optional flag.") 9 | flags.DEFINE_string("root_default", None, "A required flag defaulted at a parent.", required=True) 10 | 11 | FLAGS = flags.FLAGS 12 | 13 | 14 | def main(_): 15 | print( 16 | f"required: {FLAGS.required}, optional: {FLAGS.optional}, " 17 | f"root_default: {FLAGS.root_default}" 18 | ) 19 | 20 | 21 | if __name__ == "__main__": 22 | app.run(main) 23 | -------------------------------------------------------------------------------- /axlearn/cloud/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn cloud modules.""" 4 | import pathlib 5 | 6 | # Root of the cloud module, e.g. /path/to/axlearn. 7 | ROOT_MODULE = pathlib.Path(__file__).resolve().parent.parent 8 | ROOT_MODULE_NAME = ROOT_MODULE.name 9 | -------------------------------------------------------------------------------- /axlearn/cloud/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn cloud utilities.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/common/testdata/counter.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """A dummy script that spawns a subprocess which keeps updating a counter.""" 4 | 5 | import shlex 6 | import subprocess 7 | import sys 8 | import time 9 | 10 | 11 | def _child(path: str): 12 | print(f"emitting to {path}") 13 | for i in range(100): 14 | with open(path, "w", encoding="utf-8") as f: 15 | f.seek(0, 0) 16 | print(f"incrementing to {i}") 17 | f.write(str(i)) 18 | f.flush() 19 | time.sleep(0.1) 20 | 21 | 22 | if __name__ == "__main__": 23 | output_path, parent_or_child = sys.argv[1], sys.argv[2] 24 | 25 | if parent_or_child == "parent": 26 | # pylint: disable-next=consider-using-with 27 | p = subprocess.Popen( 28 | shlex.split(f"python3 {__file__} {output_path} child"), start_new_session=True 29 | ) 30 | print("returncode:", p.wait()) 31 | else: 32 | assert parent_or_child == "child" 33 | _child(output_path) 34 | -------------------------------------------------------------------------------- /axlearn/cloud/common/validator.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | """Utilities to validate Jobs.""" 4 | 5 | from axlearn.cloud.common.types import JobSpec 6 | from axlearn.common.config import Configurable 7 | 8 | 9 | class JobValidator(Configurable): 10 | """A job validator interface.""" 11 | 12 | def validate(self, job: JobSpec): 13 | """Raises ValidationError with reason if jobspec is invalid.""" 14 | raise NotImplementedError(type(self)) 15 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn GCP module.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 
2 | 3 | """A collection of CLI entrypoints.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/jobs/bastion_vm_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Tests bastion VM.""" 4 | # pylint: disable=protected-access 5 | 6 | from absl import flags 7 | 8 | from axlearn.cloud.gcp.jobs import bastion_vm 9 | from axlearn.cloud.gcp.test_utils import default_mock_settings, mock_gcp_settings 10 | from axlearn.common.test_utils import TestWithTemporaryCWD 11 | 12 | 13 | class MainTest(TestWithTemporaryCWD): 14 | """Tests CLI entrypoint.""" 15 | 16 | def test_private_flags(self): 17 | with mock_gcp_settings(bastion_vm.__name__, default_mock_settings()): 18 | fv = flags.FlagValues() 19 | bastion_vm._private_flags(flag_values=fv) 20 | # Basic sanity check. 21 | self.assertIsNotNone(fv["project"].default) 22 | self.assertIsNotNone(fv["zone"].default) 23 | self.assertIsNotNone(fv["env_id"].default) 24 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/monitoring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/cloud/gcp/monitoring/__init__.py -------------------------------------------------------------------------------- /axlearn/cloud/gcp/runners/base.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Base runner interface.""" 4 | 5 | from axlearn.cloud.common.bundler import Bundler 6 | from axlearn.cloud.gcp.job import GCPJob 7 | 8 | 9 | class BaseRunnerJob(GCPJob): 10 | """Base runner job interface.""" 11 | 12 | Config = GCPJob.Config 13 | 14 | def __init__(self, cfg: Config, *, bundler: Bundler): 15 | super().__init__(cfg) 16 | self._bundler = bundler 17 | -------------------------------------------------------------------------------- /axlearn/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/common/__init__.py -------------------------------------------------------------------------------- /axlearn/common/env_test.py: -------------------------------------------------------------------------------- 1 | """Tests for AXLearn environment.""" 2 | 3 | # pylint: disable=no-self-use 4 | 5 | import tensorflow_io # noqa: F401 # pylint: disable=unused-import 6 | from absl.testing import absltest 7 | from tensorflow import io as tf_io 8 | 9 | from axlearn.common import test_utils 10 | 11 | 12 | class EnvTest(test_utils.TestCase): 13 | def test_tf_io_s3_support(self): 14 | self.assertIn("s3", tf_io.gfile.get_registered_schemes()) 15 | 16 | 17 | if __name__ == "__main__": 18 | absltest.main() 19 | -------------------------------------------------------------------------------- /axlearn/common/flash_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/common/flash_attention/__init__.py -------------------------------------------------------------------------------- /axlearn/common/launch_trainer_main.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 
Apple Inc. 2 | 3 | """Main function for launching the trainer.""" 4 | 5 | from absl import app, flags 6 | 7 | from axlearn.common import launch, launch_trainer, measurement 8 | from axlearn.common.config import config_for_function 9 | 10 | 11 | def main(_): 12 | measurement.initialize(flags.FLAGS) 13 | launch.setup() 14 | trainer_config = launch_trainer.get_trainer_config() 15 | trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) 16 | measurement.start_monitoring() 17 | launch_trainer.run_trainer(trainer_config) 18 | 19 | 20 | if __name__ == "__main__": 21 | measurement.define_flags() 22 | app.run(main) 23 | -------------------------------------------------------------------------------- /axlearn/common/loss_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | """Layers for computing training time metrics.""" 4 | 5 | from axlearn.common.base_layer import BaseLayer 6 | from axlearn.common.utils import Nested, Tensor 7 | 8 | 9 | class BaseLossMetrics(BaseLayer): 10 | """A module for computing training time metrics. 11 | 12 | See `causal_lm.Model` for an example usage. 13 | """ 14 | 15 | def forward( 16 | self, 17 | input_batch: Nested[Tensor], 18 | *, 19 | predict_outputs: Nested[Tensor], 20 | module_outputs: Nested[Tensor], 21 | ) -> tuple[Tensor, Nested[Tensor]]: 22 | """Computes metrics from inputs and predictions. 23 | 24 | Args: 25 | input_batch: A mapping from input keys to Tensors. 26 | predict_outputs: Model predictions for computing metrics. 27 | module_outputs: Outputs from the model's invocation context. 28 | 29 | Returns: 30 | A tuple (loss, metrics). 31 | """ 32 | raise NotImplementedError(type(self)) 33 | -------------------------------------------------------------------------------- /axlearn/common/monitoring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/common/monitoring/__init__.py -------------------------------------------------------------------------------- /axlearn/common/normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Normalization utils.""" 4 | import jax 5 | 6 | from axlearn.common.utils import Tensor 7 | 8 | 9 | def l2_normalize(x: Tensor, eps: float = 1e-8, axis: int = -1) -> Tensor: 10 | """l2_normalize Normalizes along the dimension `axis` using an L2 norm. 11 | 12 | Args: 13 | x: Input tensor. 14 | axis: Dimension along which to normalize. 15 | eps: A lower bound value for the norm. Defaults to 1e-8. 16 | 17 | Returns: 18 | A Tensor with the same shape as x. 19 | """ 20 | sum2 = (x * x).sum(axis=axis, keepdims=True) 21 | return x * jax.lax.rsqrt(sum2 + eps) 22 | -------------------------------------------------------------------------------- /axlearn/common/normalize_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Tests normalization utils.""" 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | import tensorflow as tf 8 | 9 | from axlearn.common.normalize import l2_normalize 10 | 11 | 12 | @pytest.mark.parametrize("shape, axis", [([1, 4], -1), ([1, 3], 1), ([2, 5, 4], 2), ([1, 3, 4], 0)]) 13 | def test_l2_normalize(shape, axis): 14 | x = np.random.rand(*shape) 15 | ref = tf.math.l2_normalize(x, axis) 16 | assert jnp.allclose(l2_normalize(x, eps=1e-12, axis=axis), ref.numpy()) 17 | -------------------------------------------------------------------------------- /axlearn/common/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Custom ops.""" 4 | 5 | from ._optimization_barrier import forward_optimization_barrier 6 | -------------------------------------------------------------------------------- /axlearn/common/quantized_dot_general/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/common/quantized_dot_general/__init__.py -------------------------------------------------------------------------------- /axlearn/common/ssm_kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/common/ssm_kernels/__init__.py -------------------------------------------------------------------------------- /axlearn/common/utils_tf.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Tensorflow utils.""" 4 | from typing import Union 5 | 6 | import tensorflow as tf 7 | 8 | 9 | def masked_fill(orig: tf.Tensor, mask: tf.Tensor, fill_value: Union[int, float, str]) -> tf.Tensor: 10 | """Replaces values in `orig` with `fill_value` where `mask` is true. 11 | 12 | Args: 13 | orig: A Tensor representing the original values. 14 | mask: A boolean Tensor of the same shape as `orig`, 15 | representing where the values should be replaced. 16 | fill_value: The value to fill where mask is True. 17 | 18 | Returns: 19 | A Tensor of the same size and dtype as `orig`, but with some values replaced with 20 | `fill_value` where `mask` is true. 21 | """ 22 | fill = tf.cast(tf.fill(tf.shape(orig), fill_value), orig.dtype) 23 | return tf.where(mask, fill, orig) 24 | -------------------------------------------------------------------------------- /axlearn/common/utils_tf_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Tests Tensorflow utils.""" 4 | import tensorflow as tf 5 | from absl.testing import absltest, parameterized 6 | 7 | from axlearn.common.utils_tf import masked_fill 8 | 9 | 10 | class MaskedFillTest(parameterized.TestCase): 11 | @parameterized.parameters(tf.string, tf.int32, tf.float32) 12 | def test_basic(self, dtype): 13 | if dtype == tf.string: 14 | orig = ["a", "b", "c"] 15 | fill_value = "x" 16 | else: 17 | orig = [1, 2, 3] 18 | fill_value = tf.constant(-1, dtype=dtype) 19 | result = masked_fill( 20 | tf.convert_to_tensor(orig, dtype=dtype), 21 | mask=tf.convert_to_tensor([True, False, True]), 22 | fill_value=fill_value, 23 | ) 24 | if dtype == tf.string: 25 | self.assertSequenceEqual(result.numpy().tolist(), [b"x", b"b", b"x"]) 26 | else: 27 | self.assertSequenceEqual(result.numpy().tolist(), [-1, 2, -1]) 28 | 29 | 30 | if __name__ == "__main__": 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_128k.json: -------------------------------------------------------------------------------- 1 | { 2 | "pad_id": 0, 3 | "eos_id": 1, 4 | "unk_id": 2, 5 | "bos_id": -1, 6 | "pad_piece": "", 7 | "eos_piece": "", 8 | "unk_piece": "", 9 | "bos_piece": "", 10 | "byte_fallback": true, 11 | "vocab_size": 128000, 12 | "user_defined_symbols": "", 13 | "control_symbols": "", 14 | "character_coverage": 0.9995, 15 | "model_type": "bpe", 16 | "max_sentence_length": 32768, 17 | "allow_whitespace_only_pieces": true, 18 | "split_by_whitespace": false, 19 | "remove_extra_whitespaces": false, 20 | "shuffle_input_sentence": false, 21 | "split_digits": true, 22 | "normalization_rule_name": "identity" 23 | } 24 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_128k_c4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/data/tokenizers/sentencepiece/bpe_128k_c4.model -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_32k.json: -------------------------------------------------------------------------------- 1 | { 2 | "pad_id": 0, 3 | "eos_id": 1, 4 | "unk_id": 2, 5 | "bos_id": -1, 6 | "pad_piece": "", 7 | "eos_piece": "", 8 | "unk_piece": "", 9 | "bos_piece": "", 10 | "byte_fallback": true, 11 | "vocab_size": 32000, 12 | "user_defined_symbols": "", 13 | "control_symbols": "", 14 | "character_coverage": 0.9995, 15 | "model_type": "bpe", 16 | "max_sentence_length": 32768, 17 | "split_by_whitespace": false, 18 | "allow_whitespace_only_pieces": true, 19 | "remove_extra_whitespaces": false, 20 | "shuffle_input_sentence": false, 21 | "split_digits": true, 22 | "normalization_rule_name": "identity" 23 | } 24 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_32k_c4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/data/tokenizers/sentencepiece/bpe_32k_c4.model -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/librispeech_bpe_1024.model: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/data/tokenizers/sentencepiece/librispeech_bpe_1024.model -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/librispeech_unigram_1024.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/data/tokenizers/sentencepiece/librispeech_unigram_1024.model -------------------------------------------------------------------------------- /axlearn/experiments/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """AXLearn audio experiments.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/audio/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """AXLearn Conformer experiments.""" 4 | 5 | from . import librispeech_trainer 6 | -------------------------------------------------------------------------------- /axlearn/experiments/audio/conformer/librispeech_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Tests Conformer LibriSpeech configs.""" 4 | 5 | from axlearn.common import test_utils 6 | from axlearn.experiments.audio.conformer import librispeech_trainer 7 | 8 | 9 | class LibriSpeechTrainerTest(test_utils.TrainerConfigTestCase): 10 | """Tests LibriSpeech trainer.""" 11 | 12 | def test_trainer(self): 13 | self._test_with_trainer_config( 14 | librispeech_trainer.named_trainer_configs()["conformer-test-ctc"](), 15 | ) 16 | -------------------------------------------------------------------------------- /axlearn/experiments/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Configures pytest.""" 4 | 5 | 6 | def pytest_addoption(parser): 7 | # pylint: disable-next=import-outside-toplevel 8 | from axlearn.common.test_utils import pytest_addoption_atomic 9 | 10 | pytest_addoption_atomic( 11 | parser, 12 | "--update", 13 | action="store_true", 14 | default=False, 15 | help="If true, update all golden files.", 16 | ) 17 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.conformer_test/test_against_fairseq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.conformer_test/test_against_fairseq.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_False.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_False.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_True.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_True.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_attention.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_attention.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_decoder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_decoder.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_dense.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_dense.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_embedding.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_embedding.npy 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder_decoder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder_decoder.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_ff.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_ff.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_layer_norm.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_layer_norm.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_False.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_False.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_True.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_True.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_transformer_layer.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_transformer_layer.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.quantizer_test/test_forward_against_fairseq.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.quantizer_test/test_forward_against_fairseq.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_100.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_20.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_256.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_100.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_20.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_256.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 
1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), 
axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | 
====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), 
batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | 
====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), 
axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | 
====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | 
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, 
batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-fp8_init.txt: 
-------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- 
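The fused qkv_proj shape in these 3B init files, (3072, 40, 128), against an o_proj of (3072, 24, 128), is consistent with grouped-query attention. The short check below works through that arithmetic under the assumption of 24 query heads and 8 shared KV heads; the head split is inferred from the shapes, not stated in the files.

# Assumed head layout (not recorded in the testdata): 24 query heads, 8 KV heads.
num_q_heads, num_kv_heads, head_dim, model_dim = 24, 8, 128, 3072
fused_heads = num_q_heads + 2 * num_kv_heads  # Q heads plus one K and one V group each
assert (model_dim, fused_heads, head_dim) == (3072, 40, 128)   # qkv_proj weight shape
assert (model_dim, num_q_heads, head_dim) == (3072, 24, 128)   # o_proj weight shape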
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt: -------------------------------------------------------------------------------- 1 | 
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- 
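Each *_regularizer.txt file above records a weight_decay_scale of 1 for every parameter path, i.e. uniform weight decay across the listed weights and norm scales. As a rough illustration of how such per-parameter scales could be consumed, the sketch below multiplies a global decay rate by a per-leaf scale before adding the decay term to the gradients; the tree layout and helper name are hypothetical, not AXLearn's optimizer code.

import jax
import jax.numpy as jnp

def add_scaled_weight_decay(grads, params, scales, weight_decay=1e-4):
    # grad <- grad + weight_decay * scale * param, leaf by leaf.
    return jax.tree_util.tree_map(
        lambda g, p, s: g + weight_decay * s * p, grads, params, scales
    )

params = {"token_emb": jnp.ones((8, 4)), "output_norm_scale": jnp.ones((4,))}
grads = jax.tree_util.tree_map(jnp.zeros_like, params)
scales = {"token_emb": 1.0, "output_norm_scale": 1.0}  # every scale listed above is 1
decayed_grads = add_scaled_weight_decay(grads, params, scales)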
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | 
decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- 
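These per-configuration files appear to act as golden outputs: each *_init.txt records the rendered initializer summary and each *_regularizer.txt the weight-decay scales for one trainer config. A minimal sketch of a golden-file comparison is shown below; render_summary is a stand-in for however the summary text would actually be produced, not an AXLearn test utility.

import pathlib

def assert_matches_golden(actual_text: str, golden_path: str) -> None:
    # Compare freshly generated summary text against the checked-in golden file.
    expected = pathlib.Path(golden_path).read_text()
    assert actual_text.strip() == expected.strip(), f"summary does not match {golden_path}"

# Hypothetical usage:
# summary = render_summary(trainer_config)
# assert_matches_golden(
#     summary,
#     "axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1_regularizer.txt",
# )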
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | 
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- 
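The 7B-v1 feed-forward shapes above use a hidden dimension of 11008 for a 4096-wide model. That value is consistent with the common gated-MLP (SwiGLU) sizing rule of roughly two-thirds of four times the model width, rounded up to a multiple of 256; the check below works through that arithmetic, with the rounding rule stated as an assumption rather than something recorded in these files.

import math

model_dim = 4096
raw_hidden = (2 / 3) * 4 * model_dim            # ~10922.67 for a gated (SwiGLU) MLP
multiple_of = 256                               # assumed rounding granularity
ffn_hidden = multiple_of * math.ceil(raw_hidden / multiple_of)
assert ffn_hidden == 11008                      # matches the linear1_* / linear2 shapes above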
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, 
batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | 
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- 
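Note (editorial sketch, not part of the golden data): the `*_regularizer.txt` files record the per-parameter `weight_decay_scale` seen by `root.optimizer` (1 for every listed weight in these configs; the tiktoken variants additionally list `decoder/lm_head/weight`). A minimal sketch of how such a dump could be parsed back into a {path: scale} mapping for a golden-file comparison; the helper name `parse_weight_decay_scales` is made up for illustration.

def parse_weight_decay_scales(text):
    # Skip the "====...weight_decay_scale..." banner and read "path: scale" lines.
    scales = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("===="):
            continue
        path, _, value = line.rpartition(":")
        scales[path.strip()] = float(value)
    return scales

golden = parse_weight_decay_scales(
    "====================weight_decay_scale root.optimizer====================\n"
    "decoder/emb/token_emb/weight: 1\n"
    "decoder/output_norm/scale: 1\n"
)
assert golden["decoder/emb/token_emb/weight"] == 1.0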
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-fp8_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | 
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 
32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | 
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), 
out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | 
decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | 
====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-hybridnorm-alibi-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-hybridnorm-alibi-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-flash-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-flash-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-flash-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp_regularizer.txt -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test_init.txt: -------------------------------------------------------------------------------- 1 | backbone/stem/conv1/weight: normal(0, 1.4142135623730951 / fan_out), shape=[7, 7, 3, 4], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | backbone/stem/norm1/scale: constant(1.0) 3 | backbone/stem/norm1/bias: constant(0.0) 4 | backbone/stem/norm1/moving_mean: constant(0.0) 5 | backbone/stem/norm1/moving_variance: constant(1.0) 6 | classifier/weight: normal(0, 1.4142135623730951 / fan_out), shape=(4, 1000), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | classifier/bias: constant(0.0) -------------------------------------------------------------------------------- 
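A note on reading the *_init.txt golden files above: each line records the initializer, shape, and FanAxes for one parameter, e.g. normal(0, 1.4142135623730951 / fan_out) for the ResNet stem convolution. Assuming, purely for illustration (the golden files do not state this), that the notation corresponds to JAX-style variance scaling with the first argument as the scale, a minimal sketch of an equivalent initializer is:

import jax
import jax.numpy as jnp

# Hedged sketch: an initializer matching a spec such as
#   normal(0, 1.4142135623730951 / fan_out), shape=[7, 7, 3, 4],
#   axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
# under the assumption that it maps to variance scaling with
# scale=sqrt(2) and mode="fan_out".
init = jax.nn.initializers.variance_scaling(
    scale=1.4142135623730951,
    mode="fan_out",
    distribution="normal",
    in_axis=-2,
    out_axis=-1,
)
weight = init(jax.random.PRNGKey(0), (7, 7, 3, 4), jnp.float32)
assert weight.shape == (7, 7, 3, 4)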
/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | backbone/stem/conv1/weight: 1.0 3 | backbone/stem/norm1/bias: 0 4 | backbone/stem/norm1/moving_mean: 0 5 | backbone/stem/norm1/moving_variance: 0 6 | backbone/stem/norm1/scale: 0 7 | classifier/bias: 1.0 8 | classifier/weight: 1.0 9 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | backbone/stage0/block0/conv1/weight: 1.0 3 | backbone/stage0/block0/conv2/weight: 1.0 4 | backbone/stage0/block0/norm1/bias: 0 5 | backbone/stage0/block0/norm1/moving_mean: 0 6 | backbone/stage0/block0/norm1/moving_variance: 0 7 | backbone/stage0/block0/norm1/scale: 0 8 | backbone/stage0/block0/norm2/bias: 0 9 | backbone/stage0/block0/norm2/moving_mean: 0 10 | backbone/stage0/block0/norm2/moving_variance: 0 11 | backbone/stage0/block0/norm2/scale: 0 12 | backbone/stem/conv1/weight: 1.0 13 | backbone/stem/norm1/bias: 0 14 | backbone/stem/norm1/moving_mean: 0 15 | backbone/stem/norm1/moving_variance: 0 16 | backbone/stem/norm1/scale: 0 17 | classifier/bias: 1.0 18 | classifier/weight: 1.0 19 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn_common_measurement_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/experiments/testdata/axlearn_common_measurement_test/__init__.py -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """A dummy recorder used for measurement tests.""" 4 | 5 | from axlearn.common import measurement 6 | 7 | 8 | @measurement.register_recorder("dummy_recorder") 9 | class DummyRecorder(measurement.Recorder): 10 | @classmethod 11 | def from_flags(cls, fv) -> measurement.Recorder: 12 | del fv 13 | return cls.default_config().set(name="dummy_recorder").instantiate() 14 | -------------------------------------------------------------------------------- /axlearn/experiments/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn text experiments.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/text/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """AXLearn GPT experiments.""" 4 | 5 | from axlearn.experiments.text.gpt import ( 6 | c4_trainer, 7 | deterministic_trainer, 8 | pajama_sigmoid_trainer, 9 | pajama_trainer, 10 | ) 11 | -------------------------------------------------------------------------------- /axlearn/experiments/text/gpt/gala_sigmoid_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Tests Gala sigmoid methods.""" 4 | from absl.testing import absltest 5 | 6 | from axlearn.common import input_tf_data, utils 7 | from axlearn.common.config import config_for_function 8 | from axlearn.common.test_utils import TestCase 9 | from axlearn.common.trainer import SpmdTrainer 10 | from axlearn.experiments.text.gpt.common import mixture_train_input_source 11 | from axlearn.experiments.text.gpt.gala_sigmoid import _set_seq_len_recursively 12 | 13 | 14 | class SetConfigTest(TestCase): 15 | def test_set_seq_len_recursively(self): 16 | train_input_source = config_for_function(mixture_train_input_source).set( 17 | max_sequence_length=200 18 | ) 19 | cfg = SpmdTrainer.default_config().set( 20 | input=input_tf_data.Input.default_config().set(source=train_input_source) 21 | ) 22 | 23 | self.assertEqual(cfg.input.source.max_sequence_length, 200) 24 | _set_seq_len_recursively(cfg, max_sequence_length=100) 25 | self.assertEqual(cfg.input.source.max_sequence_length, 100) 26 | 27 | 28 | if __name__ == "__main__": 29 | with utils.numeric_checks(True): 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn vision experiments.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ImageNet trainer module.""" 4 | 5 | from .common import NUM_CLASSES, NUM_TRAIN_EXAMPLES, input_config 6 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ResNet trainer module.""" 4 | 5 | from . import imagenet_trainer 6 | from .common import learner_config, model_config 7 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/resnet/imagenet_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """ResNet on ImageNet trainer config tests.""" 4 | 5 | from axlearn.common import test_utils 6 | from axlearn.experiments.vision.resnet import imagenet_trainer 7 | 8 | 9 | class ImageNetTrainerTest(test_utils.TrainerConfigTestCase): 10 | """Tests ImageNet trainer.""" 11 | 12 | def test_trainer(self): 13 | self._test_with_trainer_config( 14 | imagenet_trainer.named_trainer_configs()["ResNet-Test"](), 15 | ) 16 | -------------------------------------------------------------------------------- /axlearn/huggingface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/huggingface/__init__.py -------------------------------------------------------------------------------- /axlearn/open_api/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/open_api/eval_set/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/open_api/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/vision/__init__.py -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/vision/imagenet_adversarial_text/__init__.py -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/openai_clip_pred_1tfrecord_target2esti.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/axlearn/vision/imagenet_adversarial_text/openai_clip_pred_1tfrecord_target2esti.pickle -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/util_imagenet.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | """ImageNet utils.""" 3 | import pickle 4 | 5 | import numpy as np 6 | 7 | # Obtained by OpenAI model's estimation, include 8 | # - target2esti: for every category, obtain its most frequent confused categories 9 | # - S: (1000x1000) np array, S = text_embedding.T * text_embedding 10 | FN_OPENAI_TARGET2ESTI = "openai_clip_pred_1tfrecord_target2esti.pickle" 11 | 12 | 13 | class ImageNet_SimilarClass: 14 | """Identifies similar classes.""" 15 | 16 | def __init__(self) -> None: 17 | with open(FN_OPENAI_TARGET2ESTI, "rb") as handle: 18 | target2esti, S = pickle.load(handle) 19 | self.target2esti = target2esti 20 | self.S = S 21 | 22 | def most_similar_class(self, c): 23 | mylist = None 24 | if c in self.target2esti: 25 | mylist = [tmp for tmp in 
self.target2esti[c] if tmp != c] 26 | if mylist: 27 | return max(set(mylist), key=mylist.count) 28 | 29 | # If there is not confused class, choose the similar from text embeddings. 30 | sc = self.S[c] 31 | sc[c] = 0 32 | return np.argmax(sc) 33 | -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/util_tfdata.py: -------------------------------------------------------------------------------- 1 | """Tensorflow data utils.""" 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def get_dataset_cardinality(ds): 8 | batch_size = 1 9 | cardinality = np.sum([1 for i in ds.batch(batch_size)]) 10 | return cardinality 11 | 12 | 13 | # The following are copied from: 14 | # https://github.com/tensorflow/models/blob/master/research/object_detection/utils/dataset_util.py 15 | 16 | 17 | def int64_feature(value): 18 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 19 | 20 | 21 | def int64_list_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 23 | 24 | 25 | def bytes_feature(value): 26 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 27 | 28 | 29 | def bytes_list_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 31 | 32 | 33 | def float_feature(value): 34 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | -------------------------------------------------------------------------------- /docs/research/mmau/figures/MMAU-herofig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/docs/research/mmau/figures/MMAU-herofig.png -------------------------------------------------------------------------------- /docs/research/mmau/figures/results_radar_bar_combined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/axlearn/383c415daefd2f09d920d6575662ca92ae8410b3/docs/research/mmau/figures/results_radar_bar_combined.png --------------------------------------------------------------------------------
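For context on the tf.train.Feature helpers in util_tfdata.py above: they are the standard building blocks for serializing records into TFRecord files. A minimal, hypothetical usage sketch follows (the feature keys are illustrative and not taken from this repository):

import tensorflow as tf

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Pack one record into a serialized tf.train.Example.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "image/encoded": bytes_feature(b"raw image bytes"),
            "image/class/label": int64_feature(7),
        }
    )
)
serialized = example.SerializeToString()  # Ready to write with tf.io.TFRecordWriter.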