├── axlearn ├── audio │ ├── __init__.py │ └── test_utils.py ├── common │ ├── __init__.py │ ├── monitoring │ │ └── __init__.py │ ├── ssm_kernels │ │ └── __init__.py │ ├── flash_attention │ │ └── __init__.py │ ├── quantized_dot_general │ │ └── __init__.py │ ├── ops │ │ └── __init__.py │ ├── env_test.py │ ├── normalize_test.py │ ├── normalize.py │ ├── launch_trainer_main.py │ ├── utils_tf.py │ ├── loss_metrics.py │ ├── utils_tf_test.py │ ├── base_model.py │ ├── learner_base.py │ └── debug_utils_test.py ├── vision │ ├── __init__.py │ ├── imagenet_adversarial_text │ │ ├── __init__.py │ │ ├── openai_clip_pred_1tfrecord_target2esti.pickle │ │ ├── util_tfdata.py │ │ ├── util_imagenet.py │ │ └── README.md │ ├── mask_generator_test.py │ └── nms_test.py ├── huggingface │ └── __init__.py ├── cloud │ ├── gcp │ │ ├── monitoring │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── jobs │ │ │ ├── __init__.py │ │ │ └── bastion_vm_test.py │ │ ├── runners │ │ │ ├── base.py │ │ │ └── utils_test.py │ │ ├── job_pathways.py │ │ └── scopes.py │ ├── common │ │ ├── __init__.py │ │ ├── validator.py │ │ └── testdata │ │ │ └── counter.py │ └── __init__.py ├── open_api │ ├── __init__.py │ ├── eval_set │ │ └── __init__.py │ └── metrics │ │ └── __init__.py ├── experiments │ ├── testdata │ │ ├── axlearn_common_measurement_test │ │ │ ├── __init__.py │ │ │ └── dummy_recorder.py │ │ ├── axlearn.experiments.text.gpt.pajama_trainer │ │ │ ├── honeycrisp-3B-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-85M-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-test-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-3B-flash-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-85M-flash-sp-rp_regularizer.txt │ │ │ ├── honeycrisp-test-flash-sp-rp_regularizer.txt │ │ │ ├── gala-1B-sp-rp_regularizer.txt │ │ │ ├── gala-7B-sp-rp_regularizer.txt │ │ │ ├── gala-302M-sp-rp_regularizer.txt │ │ │ ├── gala-85M-sp-rp_regularizer.txt │ │ │ ├── gala-test-sp-rp_regularizer.txt │ │ │ ├── gala-1B-flash-sp-rp_regularizer.txt │ │ │ ├── 
gala-302M-flash-sp-rp_regularizer.txt │ │ │ ├── gala-7B-flash-sp-rp_regularizer.txt │ │ │ ├── gala-85M-flash-sp-rp_regularizer.txt │ │ │ ├── gala-test-flash-sp-rp_regularizer.txt │ │ │ ├── gala-7B-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-1B-hybridnorm-alibi-flash-sp-rp_regularizer.txt │ │ │ ├── gala-7B-hybridnorm-alibi-flash-sp-rp_regularizer.txt │ │ │ ├── gala-test-sp-rp_init.txt │ │ │ ├── gala-test-flash-sp-rp_init.txt │ │ │ ├── honeycrisp-test-sp-rp_init.txt │ │ │ ├── honeycrisp-test-flash-sp-rp_init.txt │ │ │ ├── gala-85M-sp-rp_init.txt │ │ │ ├── gala-1B-sp-rp_init.txt │ │ │ ├── gala-302M-sp-rp_init.txt │ │ │ ├── gala-85M-flash-sp-rp_init.txt │ │ │ ├── honeycrisp-85M-sp-rp_init.txt │ │ │ ├── gala-1B-flash-sp-rp_init.txt │ │ │ ├── gala-302M-flash-sp-rp_init.txt │ │ │ ├── gala-7B-sp-rp_init.txt │ │ │ ├── honeycrisp-3B-sp-rp_init.txt │ │ │ ├── honeycrisp-85M-flash-sp-rp_init.txt │ │ │ ├── gala-7B-flash-sp-rp_init.txt │ │ │ └── honeycrisp-3B-flash-sp-rp_init.txt │ │ ├── axlearn.common.conformer_test │ │ │ └── test_against_fairseq.npy │ │ ├── axlearn.common.encoder_decoder_test │ │ │ ├── test_against_t5x_True.npy │ │ │ └── test_against_t5x_False.npy │ │ ├── axlearn.common.t5_test │ │ │ ├── test_buckets_against_t5x_False_100.npy │ │ │ ├── test_buckets_against_t5x_False_20.npy │ │ │ ├── test_buckets_against_t5x_False_256.npy │ │ │ ├── test_buckets_against_t5x_True_100.npy │ │ │ ├── test_buckets_against_t5x_True_20.npy │ │ │ └── test_buckets_against_t5x_True_256.npy │ │ ├── axlearn.common.quantizer_test │ │ │ └── test_forward_against_fairseq.npy │ │ ├── axlearn.common.param_converter_test │ │ │ ├── test_parameters_from_t5x_ff.npy │ │ │ ├── test_parameters_from_t5x_dense.npy │ │ │ ├── test_parameters_from_t5x_decoder.npy │ │ │ ├── test_parameters_from_t5x_encoder.npy │ │ │ ├── test_parameters_from_t5x_attention.npy │ │ │ ├── test_parameters_from_t5x_embedding.npy │ │ │ ├── test_parameters_from_t5x_layer_norm.npy │ │ │ ├── 
test_parameters_from_t5x_encoder_decoder.npy │ │ │ ├── test_parameters_from_t5x_rel_pos_emb_True.npy │ │ │ ├── test_parameters_from_t5x_rel_pos_emb_False.npy │ │ │ └── test_parameters_from_t5x_transformer_layer.npy │ │ ├── axlearn.experiments.vision.resnet.imagenet_trainer │ │ │ ├── ResNet-Test_regularizer.txt │ │ │ ├── ResNet-Test_init.txt │ │ │ ├── ResNet-Testb_regularizer.txt │ │ │ └── ResNet-Testb_init.txt │ │ ├── axlearn.experiments.text.gpt.c4_trainer │ │ │ ├── fuji-1B-v3_regularizer.txt │ │ │ ├── fuji-3B-v3_regularizer.txt │ │ │ ├── fuji-7B-v1_regularizer.txt │ │ │ ├── fuji-7B-v2_regularizer.txt │ │ │ ├── fuji-7B-v3_regularizer.txt │ │ │ ├── fuji-test-v1_regularizer.txt │ │ │ ├── fuji-test-v2_regularizer.txt │ │ │ ├── fuji-test-v3_regularizer.txt │ │ │ ├── fuji-1B-v3-flash_regularizer.txt │ │ │ ├── fuji-3B-v3-flash_regularizer.txt │ │ │ ├── fuji-7B-v1-flash_regularizer.txt │ │ │ ├── fuji-7B-v2-flash_regularizer.txt │ │ │ ├── fuji-7B-v3-flash_regularizer.txt │ │ │ ├── fuji-test-v1-flash_regularizer.txt │ │ │ ├── fuji-test-v2-flash_regularizer.txt │ │ │ ├── fuji-test-v3-flash_regularizer.txt │ │ │ ├── fuji-1B-v3-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-3B-v3-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-7B-v1-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2-single-host_regularizer.txt │ │ │ ├── fuji-7B-v3-single-host_regularizer.txt │ │ │ ├── fuji-golden-run-test-v1_regularizer.txt │ │ │ ├── fuji-golden-run-test-v2_regularizer.txt │ │ │ ├── fuji-golden-run-test-v3_regularizer.txt │ │ │ ├── fuji-test-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-1B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-3B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-7B-v1-flash-single-host_regularizer.txt │ │ │ ├── fuji-7B-v2-flash-single-host_regularizer.txt │ │ │ ├── 
fuji-7B-v3-flash-single-host_regularizer.txt │ │ │ ├── fuji-test-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-golden-run-test-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-70B-v1_regularizer.txt │ │ │ ├── fuji-70B-v2_regularizer.txt │ │ │ ├── fuji-70B-v3_regularizer.txt │ │ │ ├── fuji-70B-v1-flash_regularizer.txt │ │ │ ├── fuji-70B-v2-flash_regularizer.txt │ │ │ ├── fuji-70B-v3-flash_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken_regularizer.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-single-host_regularizer.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt │ │ │ ├── fuji-test-v1_init.txt │ │ │ ├── fuji-test-v2_init.txt │ │ │ ├── fuji-test-v3_init.txt │ │ │ ├── fuji-test-v1-flash_init.txt │ │ │ ├── fuji-test-v2-flash_init.txt │ │ │ ├── fuji-test-v3-flash_init.txt │ │ │ ├── fuji-test-v3-tiktoken_init.txt │ │ │ ├── fuji-golden-run-test-v1_init.txt │ │ │ ├── fuji-golden-run-test-v2_init.txt │ │ │ ├── fuji-golden-run-test-v3_init.txt │ │ │ ├── fuji-test-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-golden-run-test-v3-tiktoken_init.txt │ │ │ ├── fuji-1B-v3_init.txt │ │ │ ├── fuji-3B-v3_init.txt │ │ │ ├── fuji-7B-v1_init.txt │ │ │ ├── fuji-7B-v2_init.txt │ │ │ ├── fuji-1B-v3-flash_init.txt │ │ │ ├── fuji-1B-v3-tiktoken_init.txt │ │ │ ├── fuji-3B-v3-flash_init.txt │ │ │ ├── fuji-7B-v3_init.txt │ │ │ ├── fuji-1B-v3-single-host_init.txt │ │ │ ├── fuji-3B-v3-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken_init.txt │ │ │ ├── fuji-7B-v1-flash_init.txt │ │ │ ├── fuji-7B-v2-flash_init.txt │ │ │ ├── fuji-7B-v3-flash_init.txt │ │ │ ├── fuji-1B-v3-flash-single-host_init.txt │ │ │ ├── 
fuji-1B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-7B-v1-single-host_init.txt │ │ │ ├── fuji-7B-v2-single-host_init.txt │ │ │ ├── fuji-7B-v3-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-single-host_init.txt │ │ │ ├── fuji-3B-v3-flash-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-single-host_init.txt │ │ │ ├── fuji-7B-v1-flash-single-host_init.txt │ │ │ ├── fuji-7B-v2-flash-single-host_init.txt │ │ │ ├── fuji-7B-v3-flash-single-host_init.txt │ │ │ ├── fuji-1B-v3-tiktoken-flash-single-host_init.txt │ │ │ ├── fuji-3B-v3-tiktoken-flash-single-host_init.txt │ │ │ ├── fuji-70B-v1_init.txt │ │ │ ├── fuji-70B-v2_init.txt │ │ │ ├── fuji-70B-v3_init.txt │ │ │ ├── fuji-70B-v1-flash_init.txt │ │ │ ├── fuji-70B-v2-flash_init.txt │ │ │ ├── fuji-70B-v3-flash_init.txt │ │ │ ├── fuji-70B-v3-tiktoken_init.txt │ │ │ ├── fuji-8B-v3-tiktoken_init.txt │ │ │ ├── fuji-70B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-flash_init.txt │ │ │ ├── fuji-8B-v3-tiktoken-single-host_init.txt │ │ │ └── fuji-8B-v3-tiktoken-flash-single-host_init.txt │ │ └── axlearn.experiments.text.gpt.pajama_sigmoid_trainer │ │ │ ├── gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp_regularizer.txt │ │ │ ├── gala-sigmoid-1B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ │ │ ├── gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ │ │ └── gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt │ ├── audio │ │ ├── __init__.py │ │ └── conformer │ │ │ ├── __init__.py │ │ │ └── librispeech_trainer_test.py │ ├── text │ │ ├── __init__.py │ │ └── gpt │ │ │ ├── __init__.py │ │ │ └── gala_sigmoid_test.py │ ├── vision │ │ ├── __init__.py │ │ ├── imagenet │ │ │ └── __init__.py │ │ └── resnet │ │ │ ├── __init__.py │ │ │ ├── imagenet_trainer_test.py │ │ │ └── 
common_test.py │ ├── conftest.py │ └── trainer_config_utils.py ├── __init__.py ├── data │ └── tokenizers │ │ └── sentencepiece │ │ ├── bpe_128k_c4.model │ │ ├── bpe_32k_c4.model │ │ ├── librispeech_bpe_1024.model │ │ ├── librispeech_unigram_1024.model │ │ ├── bpe_128k.json │ │ └── bpe_32k.json └── cli │ └── testdata │ └── dummy.py ├── CODEOWNERS ├── docs └── research │ └── mmau │ └── figures │ ├── MMAU-herofig.png │ └── results_radar_bar_combined.png ├── .github └── workflows │ ├── pre-commit.yml │ └── build.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md └── conftest.py /axlearn/audio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/vision/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/common/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/common/ssm_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/axlearn/common/flash_attention/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/common/quantized_dot_general/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @ruomingp @markblee @apple/axlearn-admins 2 | -------------------------------------------------------------------------------- /axlearn/open_api/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn_common_measurement_test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/open_api/eval_set/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/open_api/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | -------------------------------------------------------------------------------- /axlearn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """AXLearn.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn GCP module.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn cloud utilities.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """AXLearn audio experiments.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """AXLearn text experiments.""" 4 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | """A collection of CLI entrypoints.""" 4 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """AXLearn vision experiments.""" 4 | -------------------------------------------------------------------------------- /docs/research/mmau/figures/MMAU-herofig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/docs/research/mmau/figures/MMAU-herofig.png -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_128k_c4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/data/tokenizers/sentencepiece/bpe_128k_c4.model -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_32k_c4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/data/tokenizers/sentencepiece/bpe_32k_c4.model -------------------------------------------------------------------------------- /axlearn/common/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Custom ops.""" 4 | 5 | from ._optimization_barrier import forward_optimization_barrier 6 | -------------------------------------------------------------------------------- /docs/research/mmau/figures/results_radar_bar_combined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/docs/research/mmau/figures/results_radar_bar_combined.png -------------------------------------------------------------------------------- /axlearn/experiments/audio/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """AXLearn Conformer experiments.""" 4 | 5 | from . 
import librispeech_trainer 6 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/librispeech_bpe_1024.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/data/tokenizers/sentencepiece/librispeech_bpe_1024.model -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/librispeech_unigram_1024.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/data/tokenizers/sentencepiece/librispeech_unigram_1024.model -------------------------------------------------------------------------------- /axlearn/experiments/vision/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ImageNet trainer module.""" 4 | 5 | from .common import NUM_CLASSES, NUM_TRAIN_EXAMPLES, input_config 6 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ResNet trainer module.""" 4 | 5 | from . 
import imagenet_trainer 6 | from .common import learner_config, model_config 7 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.conformer_test/test_against_fairseq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.conformer_test/test_against_fairseq.npy -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/openai_clip_pred_1tfrecord_target2esti.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/vision/imagenet_adversarial_text/openai_clip_pred_1tfrecord_target2esti.pickle -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_True.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_True.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_100.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_20.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_20.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_False_256.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_100.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_20.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.t5_test/test_buckets_against_t5x_True_256.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_False.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.encoder_decoder_test/test_against_t5x_False.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.quantizer_test/test_forward_against_fairseq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.quantizer_test/test_forward_against_fairseq.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_ff.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_ff.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_dense.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_dense.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_decoder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_decoder.npy -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_attention.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_attention.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_embedding.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_embedding.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_layer_norm.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_layer_norm.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder_decoder.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_encoder_decoder.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_True.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_True.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_False.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_rel_pos_emb_False.npy -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_transformer_layer.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/axlearn/main/axlearn/experiments/testdata/axlearn.common.param_converter_test/test_parameters_from_t5x_transformer_layer.npy -------------------------------------------------------------------------------- /axlearn/cloud/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn cloud modules.""" 4 | import pathlib 5 | 6 | # Root of the cloud module, e.g. /path/to/axlearn. 
7 | ROOT_MODULE = pathlib.Path(__file__).resolve().parent.parent 8 | ROOT_MODULE_NAME = ROOT_MODULE.name 9 | -------------------------------------------------------------------------------- /axlearn/experiments/text/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """AXLearn GPT experiments.""" 4 | 5 | from axlearn.experiments.text.gpt import ( 6 | c4_trainer, 7 | deterministic_trainer, 8 | pajama_sigmoid_trainer, 9 | pajama_trainer, 10 | ) 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | backbone/stem/conv1/weight: 1.0 3 | backbone/stem/norm1/bias: 0 4 | backbone/stem/norm1/moving_mean: 0 5 | backbone/stem/norm1/moving_variance: 0 6 | backbone/stem/norm1/scale: 0 7 | classifier/bias: 1.0 8 | classifier/weight: 1.0 9 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/runners/base.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Base runner interface.""" 4 | 5 | from axlearn.cloud.common.bundler import Bundler 6 | from axlearn.cloud.gcp.job import GCPJob 7 | 8 | 9 | class BaseRunnerJob(GCPJob): 10 | """Base runner job interface.""" 11 | 12 | Config = GCPJob.Config 13 | 14 | def __init__(self, cfg: Config, *, bundler: Bundler): 15 | super().__init__(cfg) 16 | self._bundler = bundler 17 | -------------------------------------------------------------------------------- /axlearn/experiments/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Configures pytest.""" 4 | 5 | 6 | def pytest_addoption(parser): 7 | # pylint: disable-next=import-outside-toplevel 8 | from axlearn.common.test_utils import pytest_addoption_atomic 9 | 10 | pytest_addoption_atomic( 11 | parser, 12 | "--update", 13 | action="store_true", 14 | default=False, 15 | help="If true, update all golden files.", 16 | ) 17 | -------------------------------------------------------------------------------- /axlearn/cloud/common/validator.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | """Utilities to validate Jobs.""" 4 | 5 | from axlearn.cloud.common.types import JobSpec 6 | from axlearn.common.config import Configurable 7 | 8 | 9 | class JobValidator(Configurable): 10 | """A job validator interface.""" 11 | 12 | def validate(self, job: JobSpec): 13 | """Raises ValidationError with reason if jobspec is invalid.""" 14 | raise NotImplementedError(type(self)) 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | """A dummy recorder used for measurement tests.""" 4 | 5 | from axlearn.common import measurement 6 | 7 | 8 | @measurement.register_recorder("dummy_recorder") 9 | class DummyRecorder(measurement.Recorder): 10 | @classmethod 11 | def from_flags(cls, fv) -> measurement.Recorder: 12 | del fv 13 | return cls.default_config().set(name="dummy_recorder").instantiate() 14 | -------------------------------------------------------------------------------- /axlearn/common/env_test.py: -------------------------------------------------------------------------------- 1 | """Tests for AXLearn environment.""" 2 | 3 | # pylint: disable=no-self-use 4 | 5 | import tensorflow_io # noqa: F401 # pylint: disable=unused-import 6 | from absl.testing import absltest 7 | from tensorflow import io as tf_io 8 | 9 | from axlearn.common import test_utils 10 | 11 | 12 | class EnvTest(test_utils.TestCase): 13 | def test_tf_io_s3_support(self): 14 | self.assertIn("s3", tf_io.gfile.get_registered_schemes()) 15 | 16 | 17 | if __name__ == "__main__": 18 | absltest.main() 19 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/resnet/imagenet_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """ResNet on ImageNet trainer config tests.""" 4 | 5 | from axlearn.common import test_utils 6 | from axlearn.experiments.vision.resnet import imagenet_trainer 7 | 8 | 9 | class ImageNetTrainerTest(test_utils.TrainerConfigTestCase): 10 | """Tests ImageNet trainer.""" 11 | 12 | def test_trainer(self): 13 | self._test_with_trainer_config( 14 | imagenet_trainer.named_trainer_configs()["ResNet-Test"](), 15 | ) 16 | -------------------------------------------------------------------------------- /axlearn/experiments/audio/conformer/librispeech_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Tests Conformer LibriSpeech configs.""" 4 | 5 | from axlearn.common import test_utils 6 | from axlearn.experiments.audio.conformer import librispeech_trainer 7 | 8 | 9 | class LibriSpeechTrainerTest(test_utils.TrainerConfigTestCase): 10 | """Tests LibriSpeech trainer.""" 11 | 12 | def test_trainer(self): 13 | self._test_with_trainer_config( 14 | librispeech_trainer.named_trainer_configs()["conformer-test-ctc"](), 15 | ) 16 | -------------------------------------------------------------------------------- /axlearn/common/normalize_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Tests normalization utils.""" 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | import tensorflow as tf 8 | 9 | from axlearn.common.normalize import l2_normalize 10 | 11 | 12 | @pytest.mark.parametrize("shape, axis", [([1, 4], -1), ([1, 3], 1), ([2, 5, 4], 2), ([1, 3, 4], 0)]) 13 | def test_l2_normalize(shape, axis): 14 | x = np.random.rand(*shape) 15 | ref = tf.math.l2_normalize(x, axis) 16 | assert jnp.allclose(l2_normalize(x, eps=1e-12, axis=axis), ref.numpy()) 17 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test_init.txt: -------------------------------------------------------------------------------- 1 | backbone/stem/conv1/weight: normal(0, 1.4142135623730951 / fan_out), shape=[7, 7, 3, 4], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | backbone/stem/norm1/scale: constant(1.0) 3 | backbone/stem/norm1/bias: constant(0.0) 4 | backbone/stem/norm1/moving_mean: constant(0.0) 5 | backbone/stem/norm1/moving_variance: constant(1.0) 6 | classifier/weight: normal(0, 1.4142135623730951 / fan_out), shape=(4, 1000), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | classifier/bias: constant(0.0) -------------------------------------------------------------------------------- /axlearn/cli/testdata/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """A dummy absl program.""" 4 | 5 | from absl import app, flags 6 | 7 | flags.DEFINE_string("required", None, "A required flag.", required=True) 8 | flags.DEFINE_string("optional", None, "An optional flag.") 9 | flags.DEFINE_string("root_default", None, "A required flag defaulted at a parent.", required=True) 10 | 11 | FLAGS = flags.FLAGS 12 | 13 | 14 | def main(_): 15 | print( 16 | f"required: {FLAGS.required}, optional: {FLAGS.optional}, " 17 | f"root_default: {FLAGS.root_default}" 18 | ) 19 | 20 | 21 | if __name__ == "__main__": 22 | app.run(main) 23 | -------------------------------------------------------------------------------- /axlearn/common/normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Normalization utils.""" 4 | import jax 5 | 6 | from axlearn.common.utils import Tensor 7 | 8 | 9 | def l2_normalize(x: Tensor, eps: float = 1e-8, axis: int = -1) -> Tensor: 10 | """l2_normalize Normalizes along the dimension `axis` using an L2 norm. 11 | 12 | Args: 13 | x: Input tensor. 14 | axis: Dimension along which to normalize. 15 | eps: A lower bound value for the norm. Defaults to 1e-8. 16 | 17 | Returns: 18 | A Tensor with the same shape as x. 
19 | """ 20 | sum2 = (x * x).sum(axis=axis, keepdims=True) 21 | return x * jax.lax.rsqrt(sum2 + eps) 22 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_128k.json: -------------------------------------------------------------------------------- 1 | { 2 | "pad_id": 0, 3 | "eos_id": 1, 4 | "unk_id": 2, 5 | "bos_id": -1, 6 | "pad_piece": "", 7 | "eos_piece": "", 8 | "unk_piece": "", 9 | "bos_piece": "", 10 | "byte_fallback": true, 11 | "vocab_size": 128000, 12 | "user_defined_symbols": "", 13 | "control_symbols": "", 14 | "character_coverage": 0.9995, 15 | "model_type": "bpe", 16 | "max_sentence_length": 32768, 17 | "allow_whitespace_only_pieces": true, 18 | "split_by_whitespace": false, 19 | "remove_extra_whitespaces": false, 20 | "shuffle_input_sentence": false, 21 | "split_digits": true, 22 | "normalization_rule_name": "identity" 23 | } 24 | -------------------------------------------------------------------------------- /axlearn/data/tokenizers/sentencepiece/bpe_32k.json: -------------------------------------------------------------------------------- 1 | { 2 | "pad_id": 0, 3 | "eos_id": 1, 4 | "unk_id": 2, 5 | "bos_id": -1, 6 | "pad_piece": "", 7 | "eos_piece": "", 8 | "unk_piece": "", 9 | "bos_piece": "", 10 | "byte_fallback": true, 11 | "vocab_size": 32000, 12 | "user_defined_symbols": "", 13 | "control_symbols": "", 14 | "character_coverage": 0.9995, 15 | "model_type": "bpe", 16 | "max_sentence_length": 32768, 17 | "split_by_whitespace": false, 18 | "allow_whitespace_only_pieces": true, 19 | "remove_extra_whitespaces": false, 20 | "shuffle_input_sentence": false, 21 | "split_digits": true, 22 | "normalization_rule_name": "identity" 23 | } 24 | -------------------------------------------------------------------------------- /axlearn/common/launch_trainer_main.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Main function for launching the trainer.""" 4 | 5 | from absl import app, flags 6 | 7 | from axlearn.common import launch, launch_trainer, measurement 8 | from axlearn.common.config import config_for_function 9 | 10 | 11 | def main(_): 12 | measurement.initialize(flags.FLAGS) 13 | launch.setup() 14 | trainer_config = launch_trainer.get_trainer_config() 15 | trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) 16 | measurement.start_monitoring() 17 | launch_trainer.run_trainer(trainer_config) 18 | 19 | 20 | if __name__ == "__main__": 21 | measurement.define_flags() 22 | app.run(main) 23 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: [pull_request, merge_group] 4 | 5 | jobs: 6 | pre-commit: 7 | runs-on: ubuntu-latest 8 | # resource_class: large 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.10' 14 | cache: 'pip' 15 | - run: pip install --upgrade pip 16 | # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) 17 | - run: pip install '.[core,dev,grain,gcp,vertexai_tensorboard]' 18 | # pylint uses approx 12GB of memory during this run, look into split to decrease? 19 | - run: pre-commit run --all-files 20 | - run: pytype -j auto . 
21 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3_regularizer.txt: 
-------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale 
root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 
| decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | 
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1_regularizer.txt: 
-------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale 
root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | 
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | 
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | 
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: 1 11 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # To run locally: 2 | # % pre-commit run -a 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.3.0 6 | hooks: 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | # These files are generated. 
10 | exclude: "axlearn/experiments/testdata/.*" 11 | - id: trailing-whitespace 12 | - repo: local 13 | hooks: 14 | - id: black 15 | name: black 16 | entry: black 17 | language: system 18 | types: [python] 19 | - id: isort 20 | name: isort 21 | entry: isort 22 | language: system 23 | types: [python] 24 | - id: pylint 25 | name: pylint 26 | entry: pylint 27 | args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"'] 28 | language: system 29 | types: [python] 30 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/lm_head/weight: 1 4 | decoder/output_norm/scale: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 8 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 12 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | backbone/stage0/block0/conv1/weight: 1.0 3 | backbone/stage0/block0/conv2/weight: 1.0 4 | backbone/stage0/block0/norm1/bias: 0 5 | backbone/stage0/block0/norm1/moving_mean: 0 
6 | backbone/stage0/block0/norm1/moving_variance: 0 7 | backbone/stage0/block0/norm1/scale: 0 8 | backbone/stage0/block0/norm2/bias: 0 9 | backbone/stage0/block0/norm2/moving_mean: 0 10 | backbone/stage0/block0/norm2/moving_variance: 0 11 | backbone/stage0/block0/norm2/scale: 0 12 | backbone/stem/conv1/weight: 1.0 13 | backbone/stem/norm1/bias: 0 14 | backbone/stem/norm1/moving_mean: 0 15 | backbone/stem/norm1/moving_variance: 0 16 | backbone/stem/norm1/scale: 0 17 | classifier/bias: 1.0 18 | classifier/weight: 1.0 19 | -------------------------------------------------------------------------------- /axlearn/cloud/gcp/jobs/bastion_vm_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Tests bastion VM.""" 4 | # pylint: disable=protected-access 5 | 6 | from absl import flags 7 | 8 | from axlearn.cloud.gcp.jobs import bastion_vm 9 | from axlearn.cloud.gcp.test_utils import default_mock_settings, mock_gcp_settings 10 | from axlearn.common.test_utils import TestWithTemporaryCWD 11 | 12 | 13 | class MainTest(TestWithTemporaryCWD): 14 | """Tests CLI entrypoint.""" 15 | 16 | def test_private_flags(self): 17 | with mock_gcp_settings(bastion_vm.__name__, default_mock_settings()): 18 | fv = flags.FlagValues() 19 | bastion_vm._private_flags(flag_values=fv) 20 | # Basic sanity check. 21 | self.assertIsNotNone(fv["project"].default) 22 | self.assertIsNotNone(fv["zone"].default) 23 | self.assertIsNotNone(fv["env_id"].default) 24 | -------------------------------------------------------------------------------- /axlearn/common/utils_tf.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Tensorflow utils.""" 4 | from typing import Union 5 | 6 | import tensorflow as tf 7 | 8 | 9 | def masked_fill(orig: tf.Tensor, mask: tf.Tensor, fill_value: Union[int, float, str]) -> tf.Tensor: 10 | """Replaces values in `orig` with `fill_value` where `mask` is true. 11 | 12 | Args: 13 | orig: A Tensor representing the original values. 14 | mask: A boolean Tensor of the same shape as `orig`, 15 | representing where the values should be replaced. 16 | fill_value: The value to fill where mask is True. 17 | 18 | Returns: 19 | A Tensor of the same size and dtype as `orig`, but with some values replaced with 20 | `fill_value` where `mask` is true. 21 | """ 22 | fill = tf.cast(tf.fill(tf.shape(orig), fill_value), orig.dtype) 23 | return tf.where(mask, fill, orig) 24 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | 
decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | 
def fake_audio(
    *,
    batch_size: int,
    seq_len: int,
    prng_key: Tensor,
    scale: float = 32768.0,
    dtype: jnp.dtype = jnp.float32,
):
    """Generates fake audio data with a fixed seed.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Length of each sequence.
        prng_key: PRNG key controlling the generated values.
        scale: Samples are drawn uniformly from [-scale, scale).
        dtype: Output dtype of the audio samples.

    Returns:
        A tuple (inputs, paddings) where inputs is [batch_size, seq_len] audio and
        paddings is a boolean [batch_size, seq_len] mask that is true past each
        sequence's random length.
    """
    sample_key, len_key = jax.random.split(prng_key)
    # Draw in float32 first, then cast, so the sampled values do not depend on dtype.
    samples = jax.random.uniform(
        sample_key,
        shape=[batch_size, seq_len],
        minval=-scale,
        maxval=scale,
        dtype=jnp.float32,
    )
    samples = samples.astype(dtype)
    # Each sequence gets a random valid length in [0, seq_len).
    lengths = jax.random.randint(len_key, shape=[batch_size, 1], minval=0, maxval=seq_len)
    positions = jnp.arange(seq_len)[None, :]
    paddings = (positions >= lengths).astype(jnp.bool)
    return samples, paddings
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale 
root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 8 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/norm/scale: 1 13 | -------------------------------------------------------------------------------- /axlearn/cloud/common/testdata/counter.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
"""A dummy script that spawns a subprocess which keeps updating a counter."""

import shlex
import subprocess
import sys
import time


def _child(path: str):
    """Repeatedly overwrites `path` with an incrementing counter value."""
    print(f"emitting to {path}")
    for i in range(100):
        # Mode "w" truncates the file and positions the handle at offset 0, so
        # each iteration replaces the previous value; no explicit seek is needed.
        with open(path, "w", encoding="utf-8") as f:
            print(f"incrementing to {i}")
            f.write(str(i))
            f.flush()
            time.sleep(0.1)


if __name__ == "__main__":
    output_path, parent_or_child = sys.argv[1], sys.argv[2]

    if parent_or_child == "parent":
        # Re-invoke this script as a detached child so tests can observe the
        # counter being updated from a separate process/session.
        # pylint: disable-next=consider-using-with
        p = subprocess.Popen(
            shlex.split(f"python3 {__file__} {output_path} child"), start_new_session=True
        )
        print("returncode:", p.wait())
    else:
        assert parent_or_child == "child"
        _child(output_path)
31 | """ 32 | raise NotImplementedError(type(self)) 33 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-hybridnorm-alibi-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | 
decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-hybridnorm-alibi-flash-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-4k-hybridnorm-alibi-sp-rp_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/common/utils_tf_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
"""Tests Tensorflow utils."""
import tensorflow as tf
from absl.testing import absltest, parameterized

from axlearn.common.utils_tf import masked_fill


class MaskedFillTest(parameterized.TestCase):
    """Tests for `masked_fill` across string and numeric dtypes."""

    @parameterized.parameters(tf.string, tf.int32, tf.float32)
    def test_basic(self, dtype):
        # Strings need a string fill value; numeric dtypes use a typed constant.
        if dtype == tf.string:
            values, fill = ["a", "b", "c"], "x"
        else:
            values, fill = [1, 2, 3], tf.constant(-1, dtype=dtype)
        out = masked_fill(
            tf.convert_to_tensor(values, dtype=dtype),
            mask=tf.convert_to_tensor([True, False, True]),
            fill_value=fill,
        )
        # Masked positions (0 and 2) are replaced; position 1 is untouched.
        expected = [b"x", b"b", b"x"] if dtype == tf.string else [-1, 2, -1]
        self.assertSequenceEqual(out.numpy().tolist(), expected)


if __name__ == "__main__":
    absltest.main()
decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-7B-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | decoder/output_norm/scale: 1 4 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 5 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 6 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 7 | decoder/transformer/repeat/layer/feed_forward/postnorm/scale: 1 8 | decoder/transformer/repeat/layer/feed_forward/prenorm/scale: 1 9 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/qkv_proj/weight: 1 10 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 11 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: 1 12 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: 1 13 | decoder/transformer/repeat/layer/self_attention/postnorm/scale: 1 14 | decoder/transformer/repeat/layer/self_attention/prenorm/scale: 1 15 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-deterministic-4k-hybridnorm-alibi-pajama-2t-32k_regularizer.txt: -------------------------------------------------------------------------------- 1 | ====================weight_decay_scale root.optimizer==================== 2 | decoder/emb/token_emb/weight: 1 3 | 
def get_dataset_cardinality(ds):
    """Counts the number of elements in a dataset by iterating over it.

    Args:
        ds: An iterable of elements (e.g. a tf.data.Dataset).

    Returns:
        The number of elements in `ds`.
    """
    # Count lazily instead of materializing a per-element list of 1s and
    # np.sum-ing it; batching into size-1 batches added nothing to the count.
    return sum(1 for _ in ds)
class ImageNet_SimilarClass:
    """Identifies similar classes."""

    def __init__(self) -> None:
        with open(FN_OPENAI_TARGET2ESTI, "rb") as handle:
            target2esti, S = pickle.load(handle)
        # target2esti: for every category, its most frequently confused categories.
        self.target2esti = target2esti
        # S: (1000x1000) text-embedding similarity matrix.
        self.S = S

    def most_similar_class(self, c):
        """Returns the class most often confused with `c`.

        Args:
            c: An ImageNet class index.

        Returns:
            The index of the class most frequently confused with `c`, falling back
            to the most similar class by text-embedding similarity when no
            confusion data exists for `c`.
        """
        mylist = None
        if c in self.target2esti:
            mylist = [tmp for tmp in self.target2esti[c] if tmp != c]
        if mylist:
            return max(set(mylist), key=mylist.count)

        # If there is no confused class, choose the most similar from text embeddings.
        # Copy the row: `self.S[c]` is a view, and zeroing `sc[c]` in place would
        # silently corrupt the stored similarity matrix for later calls.
        sc = self.S[c].copy()
        sc[c] = 0  # Exclude self-similarity from the argmax.
        return np.argmax(sc)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), 
batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), 
shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 
| decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / 
fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/text/gpt/gala_sigmoid_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | """Tests Gala sigmoid methods.""" 4 | from absl.testing import absltest 5 | 6 | from axlearn.common import input_tf_data, utils 7 | from axlearn.common.config import config_for_function 8 | from axlearn.common.test_utils import TestCase 9 | from axlearn.common.trainer import SpmdTrainer 10 | from axlearn.experiments.text.gpt.common import mixture_train_input_source 11 | from axlearn.experiments.text.gpt.gala_sigmoid import _set_seq_len_recursively 12 | 13 | 14 | class SetConfigTest(TestCase): 15 | def test_set_seq_len_recursively(self): 16 | train_input_source = config_for_function(mixture_train_input_source).set( 17 | max_sequence_length=200 18 | ) 19 | cfg = SpmdTrainer.default_config().set( 20 | input=input_tf_data.Input.default_config().set(source=train_input_source) 21 | ) 22 | 23 | self.assertEqual(cfg.input.source.max_sequence_length, 200) 24 | _set_seq_len_recursively(cfg, max_sequence_length=100) 25 | self.assertEqual(cfg.input.source.max_sequence_length, 100) 26 | 27 | 28 | if __name__ == "__main__": 29 | with utils.numeric_checks(True): 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: 
constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | 
decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 
1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 
normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | 
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | 
decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: 
constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | 
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | 
decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/cloud/gcp/runners/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Tests runner utils.""" 4 | 5 | from absl.testing import parameterized 6 | 7 | from axlearn.cloud.gcp.runners import utils as runner_utils 8 | 9 | 10 | class UtilsTest(parameterized.TestCase): 11 | @parameterized.parameters( 12 | dict(tier=None, reservation=None, expected=False), 13 | # Demoted -- should be rescheduled. 14 | dict(tier=None, reservation="test", expected=True), 15 | # Demoted -- should be rescheduled. 16 | dict(tier="1", reservation="test", expected=True), 17 | # Demoted -- should be rescheduled. 18 | dict(tier=None, reservation="test", expected=True, is_pending=True), 19 | # Promoted -- do not reschedule. Instead, let pre-emption trigger reschedule. 20 | dict(tier="0", reservation=None, expected=False), 21 | # Promoted, but job is pending. Take this opportunity to reschedule. 
22 | dict(tier="0", reservation=None, expected=True, is_pending=True), 23 | ) 24 | def test_should_recreate_job(self, tier, reservation, expected, is_pending=False): 25 | self.assertEqual( 26 | expected, runner_utils.should_recreate_job(tier, reservation, is_pending=is_pending) 27 | ) 28 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt: 
-------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/vision/imagenet_adversarial_text/README.md: -------------------------------------------------------------------------------- 1 | This folder includes scripts to generate ImageNet Adversarial Text Regions (ImageNet-Atr) dataset, which is used as the evaluation set in our paper ["Less is More: Removing Text-regions Improves CLIP Training Efficiency and Robustness", 2023](https://arxiv.org/abs/2305.05095). 2 | 3 | To generate the dataset by your own: see [add_attack_tfrecord.py](add_attack_tfrecord.py). It does not rely on other AXLearn libraries. 
Instead, we can run the code on a laptop with just the current folder. 4 | 5 | The dataset is available on GCP as a Tensorflow Dataset: `gs://axlearn-public/tensorflow_datasets/imagenet2012_ocr_attack/`. 6 | 7 | 8 | ## Recognition results of different CLIP models 9 | 10 | | Model | Num. Training | ImageNet2012 Top-1 Acc | ImageNet_Atr Top-1 Acc| 11 | | ---------------- | ----------- |----------------------- |------------- | 12 | | OpenAI CLIP B16 | 400M | 68.35% | 31.65% | 13 | | OpenCLIP B16 | 400M | 66.99% | 29.55% | 14 | | Our CLIP B16 | 1.1B | 68.66% | 35.73% | 15 | | Our Filter-based CLIP B16 | 0.7B | 70.77% | 68.78% | 16 | 17 | See [zeroshot_eval_with_opensource_clip.ipynb](zeroshot_eval_with_opensource_clip.ipynb) for evaluation code. 18 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build-and-test 2 | on: [pull_request, push, merge_group] 3 | 4 | jobs: 5 | build-and-test-job: 6 | runs-on: ubuntu-latest 7 | strategy: 8 | matrix: 9 | test-group: [a, b, c, d, e] 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: docker/setup-buildx-action@v3 13 | - name: Gather test files 14 | run: find axlearn -name '*_test.py' > pytest_files.txt 15 | - name: Split test files into groups 16 | # GNU split lets us do "-n r/5" to round robin into 5 files without breaking lines 17 | # BSD split requires knowing the number of lines and uses "-l XX" 18 | run: split -n r/5 -a 1 pytest_files.txt split_pytest_files 19 | - name: Select a test group 20 | run: tr '\n' ' ' < split_pytest_files${{ matrix.test-group }} > test_files_oneline 21 | - name: Read test inputs 22 | id: test-selector 23 | run: echo "PYTEST_FILES='$(cat test_files_oneline)'" >> "$GITHUB_OUTPUT" 24 | - name: Run tests 25 | uses: docker/build-push-action@v6 26 | with: 27 | push: false 28 | target: ci 29 | context: . 
30 | build-args: | 31 | SKIP_PRECOMMIT=--skip-pre-commit 32 | PYTEST_FILES=${{ steps.test-selector.outputs.PYTEST_FILES }} 33 | -------------------------------------------------------------------------------- /axlearn/common/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Base model definition.""" 4 | 5 | from axlearn.common.base_layer import BaseLayer 6 | from axlearn.common.module import NestedTensor, Tensor 7 | 8 | 9 | class BaseModel(BaseLayer): 10 | """The base class of a model. 11 | 12 | Some subclasses also implement a `predict` method: 13 | 14 | def predict(self, input_batch: NestedTensor, **kwargs) -> NestedTensor: 15 | Computes predictions with the given inputs. 16 | 17 | Args: 18 | input_batch: a NestedTensor representing an input batch, containing Tensors with a 19 | leading dimension of `batch_size`. 20 | 21 | Returns: 22 | A NestedTensor containing Tensors with a leading dimension of `batch_size`. 23 | """ 24 | 25 | def forward(self, input_batch: NestedTensor) -> tuple[Tensor, NestedTensor]: 26 | """Computes loss and auxiliary outputs with the given inputs. 27 | 28 | Args: 29 | input_batch: a NestedTensor representing an input batch. 30 | 31 | Returns: 32 | (loss, aux), where `loss` is a scalar Tensor representing the model loss and `aux` 33 | is a NestedTensor containing model-specific auxiliary outputs. 34 | """ 35 | raise NotImplementedError(type(self)) 36 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | ## 0.1.6 4 | 5 | * Changes 6 | * Upgrade Jax from 0.4.37 to 0.4.38. 7 | * Removes all QRM (queued resource manager) codepaths from `axlearn.cloud.gcp`. 8 | * Introduces `named_runner_configs`. See `axlearn gcp launch --help` for details. 9 | * Upgrade Grain from 0.2.3 to 0.2.7. 
This removes `input_grain.trim_and_pack_dataset`. 10 | 11 | ## 0.1.5 12 | 13 | * Changes 14 | * Upgrade Jax from 0.4.33 to 0.4.37. 15 | 16 | ## 0.1.4 17 | 18 | * Changes 19 | * Upgrade Jax from 0.4.33 to 0.4.34. 20 | * Updates the `input_base.Input` API to support configuring input partitioning behavior. 21 | * The config fields `batch_axis_names` and `seq_axis_names` in `causal_lm.Model` are now deprecated. Please use `input_base.Input.input_partitioner` instead. 22 | * Updates the `causal_lm.Model` API to support configuring metrics without subclassing. This requires a golden config change. 23 | 24 | ## 0.1.3 25 | 26 | * Changes 27 | * Upgrade Jax from 0.4.30 to 0.4.33. 28 | 29 | ## 0.1.2 30 | 31 | * Changes 32 | * Upgrade Python to 3.10 33 | * Fall back to triton backend for qkv in fp32 or with bias on gpu flash attention. 34 | 35 | ## 0.1.1 36 | 37 | * Changes 38 | * Upgrade Jax from 0.4.28 to 0.4.30. 39 | 40 | ## 0.1.0 (Aug 22, 2024) 41 | 42 | * Changes 43 | * Add changelog. 44 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb_init.txt: -------------------------------------------------------------------------------- 1 | backbone/stem/conv1/weight: normal(0, 1.4142135623730951 / fan_out), shape=[7, 7, 3, 4], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | backbone/stem/norm1/scale: constant(1.0) 3 | backbone/stem/norm1/bias: constant(0.0) 4 | backbone/stem/norm1/moving_mean: constant(0.0) 5 | backbone/stem/norm1/moving_variance: constant(1.0) 6 | backbone/stage0/block0/conv1/weight: normal(0, 1.4142135623730951 / fan_out), shape=[3, 3, 4, 4], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | backbone/stage0/block0/norm1/scale: constant(1.0) 8 | backbone/stage0/block0/norm1/bias: constant(0.0) 9 | backbone/stage0/block0/norm1/moving_mean: constant(0.0) 10 | backbone/stage0/block0/norm1/moving_variance: constant(1.0) 11 | 
backbone/stage0/block0/conv2/weight: normal(0, 1.4142135623730951 / fan_out), shape=[3, 3, 4, 4], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 12 | backbone/stage0/block0/norm2/scale: constant(1.0) 13 | backbone/stage0/block0/norm2/bias: constant(0.0) 14 | backbone/stage0/block0/norm2/moving_mean: constant(0.0) 15 | backbone/stage0/block0/norm2/moving_variance: constant(1.0) 16 | classifier/weight: normal(0, 1.4142135623730951 / fan_out), shape=(4, 1000), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 17 | classifier/bias: constant(0.0) -------------------------------------------------------------------------------- /axlearn/vision/mask_generator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Tests mask generator.""" 4 | import numpy as np 5 | from absl.testing import absltest, parameterized 6 | 7 | from axlearn.vision import mask_generator 8 | 9 | 10 | class MaskGeneratorTest(parameterized.TestCase): 11 | """Tests MaskingGenerator.""" 12 | 13 | @parameterized.product( 14 | num_masking_patches=(0, 118, 196), 15 | max_aspect=(None, 0.8, 10.0), 16 | max_mask_patches=(None, 100), 17 | ) 18 | def test_mask_generation(self, num_masking_patches, max_aspect, max_mask_patches): 19 | input_size = 14 20 | model = mask_generator.MaskingGenerator( 21 | input_size=(input_size, input_size), 22 | num_masking_patches=num_masking_patches, 23 | max_aspect=max_aspect, 24 | min_mask_patches=16, 25 | max_mask_patches=max_mask_patches, 26 | ) 27 | mask = model() 28 | self.assertEqual(mask.sum(), num_masking_patches) 29 | if num_masking_patches == 0: 30 | np.testing.assert_array_equal(mask, np.zeros(shape=mask.shape, dtype=np.int32)) 31 | if num_masking_patches == input_size * input_size: 32 | np.testing.assert_array_equal(mask, np.ones(shape=mask.shape, dtype=np.int32)) 33 | 34 | 35 | if __name__ == "__main__": 36 | absltest.main() 37 | 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), 
axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/common/learner_base.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | """Interfaces for learners and modules used inside of them.""" 3 | 4 | from __future__ import annotations 5 | 6 | from typing import Any 7 | 8 | from axlearn.common.base_layer import ParameterSpec 9 | from axlearn.common.module import Module 10 | from axlearn.common.optimizer_base import OptParam 11 | from axlearn.common.utils import Nested 12 | 13 | 14 | class LearnerModule(Module): 15 | """Any stateful module used inside a `BaseLearner`, including the learner itself. 16 | 17 | E.g., an `Ema` module that could be used to compute an EMA in all the places we need to compute 18 | an EMA in optimizers. 19 | """ 20 | 21 | def create_state_partition_specs(self, model_param_specs: Nested[ParameterSpec]) -> Any: 22 | """Creates learner state partition_specs. 23 | 24 | The return type is a pytree with `TensorSpec`s as leaves. 25 | Must have the same tree structure returned by `init()`. 
26 | """ 27 | raise NotImplementedError(type(self)) 28 | 29 | def init(self, model_params: Nested[OptParam]) -> Any: 30 | """Initializes learner state. 31 | 32 | The return type is a pytree with `Tensor`s as leaves. 33 | Must have the same tree structure returned by `create_state_partition_specs()`. 34 | """ 35 | raise NotImplementedError(type(self)) 36 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | 
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), 
axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), 
shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | 
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 7 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 8 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/output_norm/scale: constant(1.0) 10 | decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | 
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- 
# Copyright © 2025 Apple Inc.

"""A helper module to launch and manage Pathways jobset on GKE."""

from typing import Any

from axlearn.cloud.gcp.job import GKEJob
from axlearn.cloud.gcp.pathways_utils import (
    _PATHWAYS_HEAD_REPLICATED_JOB_NAME,
    PathwaysReplicatedJob,
)
from axlearn.common.utils import Nested


class GKEPathwaysJobSet(GKEJob):
    """A Job that manages Pathways jobset"""

    def __init__(self, cfg: GKEJob.Config, *, bundler):
        super().__init__(cfg, bundler=bundler)
        # TODO(ethanli): Refactor to generalize so we don't need the special case here.
        builder_cfg = cfg.builder
        if not isinstance(builder_cfg, PathwaysReplicatedJob.Config):
            raise NotImplementedError(type(builder_cfg))

    def _build_jobset(self) -> Nested[Any]:
        """Extends the base jobset spec with a Pathways coordinator and success policy."""
        jobset_spec = super()._build_jobset()

        # TODO (ethanli): Consider refactoring with the modifiers pattern.
        # The Pathways head replicated job (job 0, pod 0) acts as the jobset coordinator.
        jobset_spec["spec"]["coordinator"] = {
            "replicatedJob": _PATHWAYS_HEAD_REPLICATED_JOB_NAME,
            "jobIndex": 0,
            "podIndex": 0,
        }

        # The jobset succeeds once the Pathways head replicated job completes.
        jobset_spec["spec"]["successPolicy"] = {
            "operator": "All",
            "targetReplicatedJobs": [
                _PATHWAYS_HEAD_REPLICATED_JOB_NAME,
            ],
        }

        return jobset_spec
# Copyright © 2023 Apple Inc.

"""Tests NMS utils."""
import jax.numpy as jnp
import tensorflow as tf
from absl.testing import absltest, parameterized

from axlearn.vision import nms


class NMSTest(parameterized.TestCase, tf.test.TestCase):
    def test_nms(self):
        detection_scores = jnp.asarray([0.9, 0.8, 0.7])
        detection_boxes = jnp.asarray(
            [
                [1.0, 10.0, 2, 20.0],  # will be kept.
                [1.1, 10.1, 2.1, 20.1],  # will be suppressed by the first box.
                [30.0, 50.0, 40.0, 60.0],  # will be kept.
            ]
        )

        # Add a leading batch dimension of size 1; suppress overlaps above IoU 0.5.
        nmsed_scores, nmsed_boxes = nms.non_max_suppression_padded(
            detection_scores[None, :],
            detection_boxes[None, :, :],
            3,  # max_output_size
            0.5,  # iou_threshold
        )
        # The suppressed detection is replaced by a zero score and an all-zero padded box.
        self.assertAllClose(nmsed_scores, [[0.9, 0.7, 0.0]])
        self.assertAllClose(
            nmsed_boxes,
            [
                [
                    [1.0, 10.0, 2, 20.0],
                    [30.0, 50.0, 40.0, 60.0],
                    [0.0, 0.0, 0.0, 0.0],  # padded box.
                ]
            ],
        )


if __name__ == "__main__":
    absltest.main()
fan_in), shape=(2048, 5632), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(5632, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 1024], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1024, 16, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1024, 16, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(1024, 2816), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(1024, 2816), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(2816, 1024), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) 
-------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 768], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 12, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 12, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(768, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(768, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(2048, 768), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), 
shape=[49152, 768], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 24, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 12, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(768, 2304), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(768, 2304), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(2304, 768), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), 
batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 5632), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 5632), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(5632, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 1024], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(1024, 16, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(1024, 16, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | 
decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(1024, 2816), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(1024, 2816), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(2816, 1024), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, 
batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[49152, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8064), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8064), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8064, 3072), axes=FanAxes(in_axis=-2, 
out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[49152, 768], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 24, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(768, 12, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(768, 2304), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(768, 2304), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(2304, 768), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/cloud/gcp/scopes.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Pre-defined scopes for gcloud credentials. 4 | 5 | See: https://developers.google.com/identity/protocols/oauth2/scopes 6 | """ 7 | 8 | _OPEN_ID = "openid" # Google's OAuth 2.0 API for OpenID Connect. 9 | 10 | # Detailed documentation and permissions for each scope can be found on 11 | # https://developers.google.com/identity/protocols/oauth2/scopes 12 | _CLOUD_PLATFORM = "https://www.googleapis.com/auth/cloud-platform" # general GCP API call 13 | _CLOUD_TPU = "https://www.googleapis.com/auth/cloud.tpu" # TPU operations 14 | _COMPUTE = "https://www.googleapis.com/auth/compute" # GCE operations 15 | # Read primary email address of token owner. 16 | _EMAIL = "https://www.googleapis.com/auth/userinfo.email" 17 | _SQL_LOGIN = "https://www.googleapis.com/auth/sqlservice.login" # CloudSQL login 18 | _STORAGE_RW = "https://www.googleapis.com/auth/devstorage.read_write" # GCS Read-Write 19 | 20 | 21 | # For typical TPU node and queued resource operations, including list, create, delete etc. 22 | # See: https://cloud.google.com/tpu/docs/reference/rest 23 | # TODO(Zhaoyi): create more granular scopes for TPU operations. 24 | DEFAULT_TPU_SCOPES = [_CLOUD_TPU, _CLOUD_PLATFORM] 25 | 26 | # Same scopes used by gcloud auth application-default login. 
27 | # based on https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login 28 | DEFAULT_APPLICATION = [_OPEN_ID, _EMAIL, _CLOUD_PLATFORM, _SQL_LOGIN] 29 | -------------------------------------------------------------------------------- /axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 11008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(11008, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- 
/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-flash-sp-rp_init.txt: -------------------------------------------------------------------------------- 1 | decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[49152, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 2 | decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 3 | decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 4 | decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 5 | decoder/transformer/repeat/layer/self_attention/attention/scale_query/norm/scale: constant(1.0) 6 | decoder/transformer/repeat/layer/self_attention/attention/scale_key/norm/scale: constant(1.0) 7 | decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) 8 | decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8064), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 9 | decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8064), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 10 | decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8064, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 11 | decoder/output_norm/scale: constant(1.0) -------------------------------------------------------------------------------- /axlearn/common/debug_utils_test.py: -------------------------------------------------------------------------------- 1 | """Tests for debug_utils.py""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.experimental import checkify 6 | 7 | from axlearn.common import test_utils 8 | from 
axlearn.common.debug_utils import checkify_pjit 9 | 10 | 11 | class TestPjitWrappers(test_utils.TestCase): 12 | """Tests for functions that produce `JitFn` wrappers.""" 13 | 14 | def setUp(self): 15 | super().setUp() 16 | self.enter_context(jax.checking_leaks()) 17 | 18 | def test_checkify_pjit(self): 19 | """Tests `checkify_pjit` in "JIT" mode.""" 20 | wrapped_pjit = checkify_pjit(errors=checkify.float_checks) 21 | 22 | @wrapped_pjit 23 | def fn(x, y): 24 | return x / y 25 | 26 | self.assertNestedAllClose(fn(8, 2), 4.0, atol=0, rtol=1e-6) 27 | with self.assertRaisesRegex(checkify.JaxRuntimeError, "division by zero"): 28 | fn(6, 0) 29 | 30 | def test_checkify_pjit_compiled(self): 31 | """Tests `checkify_pjit` in "ahead-of-time ompiled" mode.""" 32 | wrapped_pjit = checkify_pjit(errors=checkify.nan_checks) 33 | 34 | @wrapped_pjit 35 | def fn(x): 36 | return jnp.log(x) 37 | 38 | compiled_fn = fn.lower(1.0).compile() 39 | self.assertNestedAllClose(compiled_fn(jnp.exp(1.0)), 1.0, atol=0, rtol=1e-6) 40 | with self.assertRaisesRegex(checkify.JaxRuntimeError, "nan generated by primitive: log."): 41 | compiled_fn(-1.0) 42 | -------------------------------------------------------------------------------- /axlearn/experiments/vision/resnet/common_test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | """Tests ResNet config builders.""" 4 | 5 | from absl.testing import parameterized 6 | 7 | from axlearn.common import schedule 8 | from axlearn.common.config import config_for_function 9 | from axlearn.common.test_utils import TestCase 10 | from axlearn.experiments.vision.resnet.common import learner_config, model_config 11 | from axlearn.vision.resnet import ResNet 12 | 13 | 14 | class ConfigTest(TestCase): 15 | """Tests configs.""" 16 | 17 | @parameterized.product( 18 | learning_rate=[ 19 | 1.0, 20 | schedule.polynomial(end_value=10), 21 | config_for_function(schedule.polynomial).set(end_value=10), 22 | ], 23 | ema_decay=[None, 0.9], 24 | ) 25 | def test_learner_config(self, **kwargs): 26 | cfg = learner_config(**kwargs) 27 | self.assertEqual(cfg.optimizer.learning_rate, kwargs["learning_rate"]) 28 | # Make sure that we can instantiate. 29 | learner = cfg.set(name="test").instantiate(parent=None) 30 | if kwargs["ema_decay"] is not None: 31 | self.assertIsNotNone(learner.ema) 32 | 33 | def test_model_config(self): 34 | cfg = model_config() 35 | # We should be able to cfg.set(backbone=..., num_classes=...). 36 | cfg.set(backbone=ResNet.resnet18_config(), num_classes=100) 37 | # Make sure we can instantiate. 38 | cfg.set(name="test").instantiate(parent=None) 39 | -------------------------------------------------------------------------------- /axlearn/experiments/trainer_config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """Trainer config utilities.""" 4 | from typing import Optional 5 | 6 | from typing_extensions import Protocol 7 | 8 | from axlearn.common.config import InstantiableConfig, config_class 9 | from axlearn.common.flash_attention.layer import FlashBlockSizeModifier 10 | 11 | 12 | class TrainerConfigFn(Protocol): 13 | """A TrainerConfigFn takes a data_dir as argument and returns a Config for instantiating a 14 | Trainer, e.g. SpmdTrainer. 
15 | """ 16 | 17 | # Note: avoid using SpmdTrainer.Config so we don't need to introduce a dependency to trainer. 18 | # This also makes it possible to define custom trainers with the same protocol. 19 | def __call__(self, data_dir: Optional[str] = None) -> InstantiableConfig: 20 | ... 21 | 22 | 23 | def with_overrides(trainer_config_fn: TrainerConfigFn, **kwargs) -> TrainerConfigFn: 24 | """Patches the trainer config produced by the trainer_config_fn.""" 25 | 26 | def wrapped_fn(): 27 | trainer_cfg = trainer_config_fn() 28 | trainer_cfg.set(**kwargs) 29 | return trainer_cfg 30 | 31 | return wrapped_fn 32 | 33 | 34 | class V6eFlashConfigModifier(FlashBlockSizeModifier): 35 | """Modified the tpu_block_size config for better performance on TPU v6e.""" 36 | 37 | @config_class 38 | class Config(FlashBlockSizeModifier.Config): 39 | """Configures V6eFlashConfigModifier.""" 40 | 41 | tpu_block_size: int = 1024 42 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """Configures pytest to distribute tests to multiple GPUs. 4 | 5 | This is not enabled by default and requires explicit opt-in by setting the environment variable 6 | `AXLEARN_CI_GPU_TESTS`. This is because not all GPU tests are single-GPU tests. 7 | An additional environment variable `AXLEARN_CI_NUM_DEVICES_PER_WORKER` can be set to control the 8 | number of GPUs visible to each worker. 
9 | 10 | Example usage on 8-GPU machines: 11 | - 1 GPU per worker: 12 | AXLEARN_CI_GPU_TESTS=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py 13 | - 4 GPUs per worker: 14 | AXLEARN_CI_GPU_TESTS=1 pytest -n 32 axlearn/common/flash_attention/gpu_attention_test.py 15 | """ 16 | import os 17 | 18 | 19 | # pylint: disable-next=unused-argument 20 | def pytest_configure(config): 21 | if "AXLEARN_CI_GPU_TESTS" not in os.environ: 22 | return 23 | worker_idx = int(os.getenv("PYTEST_XDIST_WORKER", "gw0").lstrip("gw")) 24 | # Evenly distribute work to all GPUs. 25 | num_devices_per_worker = int(os.environ.get("AXLEARN_CI_NUM_DEVICES_PER_WORKER", "1")) 26 | num_devices = int(os.environ.get("AXLEARN_CI_NUM_DEVICES", "8")) 27 | starting_device_idx = ( 28 | worker_idx % (num_devices // num_devices_per_worker) 29 | ) * num_devices_per_worker 30 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 31 | str(device_idx) 32 | for device_idx in range(starting_device_idx, starting_device_idx + num_devices_per_worker) 33 | ) 34 | --------------------------------------------------------------------------------