├── paxml
├── docs
│ ├── images
│ │ ├── 16B-loss.png
│ │ ├── 16B-pplx.png
│ │ ├── 1B-loss.png
│ │ ├── 1B-pplx.png
│ │ ├── GPT3-XL-loss.png
│ │ ├── GPT3-XL-pplx.png
│ │ └── Weak_scaling_of_large_language_model_training_on_TPU_v4.png
│ ├── README.md
│ ├── tasks.md
│ ├── tutorials
│ │ ├── inputs_in_Pax-eval.ipynb
│ │ └── inputs_in_Pax-train.ipynb
│ ├── models.md
│ └── hands-on-tutorials.md
├── tasks
│ ├── lm
│ │ ├── testdata
│ │ │ └── tfrecords
│ │ ├── __init__.py
│ │ ├── BUILD
│ │ ├── params
│ │ │ ├── optimal_scaling.py
│ │ │ ├── c4_test.py
│ │ │ └── BUILD
│ │ └── input_generator_test.py
│ ├── test
│ │ ├── BUILD
│ │ └── synthetic.py
│ └── vision
│ │ ├── BUILD
│ │ ├── input_generator_test.py
│ │ └── params
│ │ │ └── BUILD
├── contrib
│ └── gpu
│ │ ├── README.md
│ │ └── scripts_gpu
│ │ ├── download_boolq.py
│ │ ├── download_the_pile.py
│ │ ├── download_lambada.py
│ │ ├── checkpoint_utils.py
│ │ ├── run_pile_singlenode.sh
│ │ ├── run_lambada_singlenode.sh
│ │ ├── run_base_config_multinode.sh
│ │ ├── run_pile_multinode.sh
│ │ └── lora_utils.py
├── AUTHORS
├── pip_package
│ ├── cloudbuild-postsubmit.yaml
│ ├── cloudbuild-presubmit.yaml
│ ├── postsubmit.Dockerfile
│ ├── compile_requirements_helper.sh
│ ├── cloudbuild.yaml
│ ├── presubmit.Dockerfile
│ ├── cloudbuild-release.yaml
│ ├── collect_wheels.sh
│ ├── compile_requirements.sh
│ ├── build_pip_pkg.sh
│ ├── release.Dockerfile
│ ├── build.sh
│ ├── prepare_release.sh
│ └── Dockerfile
├── first_result_metric_callback.py
├── test_helper.py
├── lineage_logging.py
├── tools
│ ├── fiddle
│ │ ├── graphviz_utils_test.py
│ │ ├── wrap_nested_maps.py
│ │ ├── convert_seqio_task_objects_test.py
│ │ ├── codegen_pax_code_ir.py
│ │ ├── convert_seqio_task_objects.py
│ │ ├── wrap_nested_maps_test.py
│ │ ├── codegen_highlevel_parameterization_test.py
│ │ ├── graphviz_utils.py
│ │ ├── unshare_sharding.py
│ │ ├── remove_sharding.py
│ │ ├── make_parameterized_experiment.py
│ │ ├── codegen_highlevel_parameterization.py
│ │ ├── codegen_tracer.py
│ │ ├── codegen_external_init_checkpoint_fns_test.py
│ │ └── remove_sharding_test.py
│ ├── validate_config.py
│ ├── dump_hparams.py
│ ├── dump_input_specs.py
│ ├── BUILD
│ └── dump_input_specs_lib.py
├── experimental
│ ├── nested_map_config_helper_test.py
│ ├── nested_map_config_helper.py
│ └── BUILD
├── base_task.py
├── paxml.bzl
├── ml_monitoring.py
├── checkpoint_types.py
├── experiment_imports_all_test.py
├── partitioning_test.py
├── base_executor.py
├── experiment_vars_summary_parser.py
├── checkpoint_version.py
├── main_test.py
├── experiment_vars_summary_test.py
├── host_callback_test.py
├── setup_jax.py
├── ghostnorm
│ └── BUILD
├── base_inference_runner_test.py
├── profiling.py
├── base_inference_runner.py
├── xla_passthrough.py
├── train_states.py
├── xla_passthrough_test.py
├── host_callback.py
├── requirements.in
├── CONTRIBUTING.md
├── WORKSPACE
└── setup.py
/paxml/docs/images/16B-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/16B-loss.png
--------------------------------------------------------------------------------
/paxml/docs/images/16B-pplx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/16B-pplx.png
--------------------------------------------------------------------------------
/paxml/docs/images/1B-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/1B-loss.png
--------------------------------------------------------------------------------
/paxml/docs/images/1B-pplx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/1B-pplx.png
--------------------------------------------------------------------------------
/paxml/docs/images/GPT3-XL-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/GPT3-XL-loss.png
--------------------------------------------------------------------------------
/paxml/docs/images/GPT3-XL-pplx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/GPT3-XL-pplx.png
--------------------------------------------------------------------------------
/paxml/tasks/lm/testdata/tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/tasks/lm/testdata/tfrecords
--------------------------------------------------------------------------------
/paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png
--------------------------------------------------------------------------------
/paxml/contrib/gpu/README.md:
--------------------------------------------------------------------------------
1 | # Pax on GPUs
2 | Please refer to [Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax), NVIDIA's project that enables seamless training of LLMs, CV models and multimodal models in JAX, for information about running experiments on GPUs in Pax.
3 |
--------------------------------------------------------------------------------
/paxml/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Pax's significant contributors.
2 | #
3 | # This does not necessarily list everyone who has contributed code,
4 | # especially since many employees of one corporation may be contributing.
5 | # To see the full list of contributors, see the revision history in
6 | # source control.
7 | Google LLC
8 | NVIDIA Corporation
--------------------------------------------------------------------------------
/paxml/pip_package/cloudbuild-postsubmit.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '--build-arg', 'image_name=${_IMAGE_NAME}',
6 | '-f', 'paxml/pip_package/postsubmit.Dockerfile', '.'
7 | ]
8 |
9 | substitutions:
10 | _PYTHON_VERSION: '3.10'
11 | _RELEASE_VERSION: 'nightly' # or rX.Y
12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
13 | options:
14 | dynamic_substitutions: true
15 | substitution_option: 'ALLOW_LOOSE'
16 | timeout: 1200s
17 |
--------------------------------------------------------------------------------
/paxml/pip_package/cloudbuild-presubmit.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '--build-arg', 'image_name=${_IMAGE_NAME}',
6 | '-f', 'paxml/pip_package/presubmit.Dockerfile', '.'
7 | ]
8 |
9 | substitutions:
10 | _PYTHON_VERSION: '3.10'
11 | _RELEASE_VERSION: 'nightly' # or rX.Y
12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
13 | options:
14 | dynamic_substitutions: true
15 | substitution_option: 'ALLOW_LOOSE'
16 | machineType: E2_HIGHCPU_8
17 | timeout: 1200s
18 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | # To update requirements.txt for praxis and paxml, run:
2 | # bash ./compile_requirements.sh
3 |
4 | absl-py
5 | clu @ git+https://github.com/google/CommonLoopUtils
6 | etils
7 | flax @ git+https://github.com/google/flax
8 | graphviz
9 | jax @ git+https://github.com/google/jax
10 | lingvo
11 | numpy
12 | orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint
13 | praxis
14 | protobuf==3.19.6
15 | pyglove
16 | seqio-nightly
17 | t5
18 | tensorflow~=2.9.2
19 | tensorflow-text~=2.9.0
20 | tensorstore
21 | tensorflow-datasets==4.8.3
22 | tfds-nightly==4.8.3.dev202303280045
23 | tensorflow-metadata==1.12.0
24 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/paxml/first_result_metric_callback.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | def train_first_result_callback_fn(metric: bool) -> None:
17 | return
18 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/download_boolq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensorflow_datasets as tfds
17 |
18 | # This will download 'super_glue/boolq' to TFDS_DATA_DIR (environment variable).
19 | ds = tfds.load('super_glue/boolq')
20 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/download_the_pile.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import paxml.contrib.gpu.scripts_gpu.tfds_pile
17 | import tensorflow_datasets as tfds
18 |
19 | # This will download 'ThePile' to TFDS_DATA_DIR (environment variable).
20 | ds = tfds.load('ThePile')
21 |
--------------------------------------------------------------------------------
/paxml/test_helper.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Helpers for unit tests."""
17 |
18 | import os
19 |
20 |
21 | def test_src_dir_path(relative_path: str) -> str:
22 | return os.path.join(os.environ['TEST_SRCDIR'], '__main__/paxml',
23 | relative_path)
24 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/download_lambada.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import paxml.contrib.gpu.scripts_gpu.tfds_lambada
17 | import tensorflow_datasets as tfds
18 |
19 | # This will download 'MyLambada' to TFDS_DATA_DIR (environment variable).
20 | ds = tfds.load('MyLambada')
21 |
--------------------------------------------------------------------------------
/paxml/lineage_logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utils for lineage logging."""
17 |
18 | # internal lineage logging module import
19 |
20 |
21 | def log_config_file_lineage(filepath: str) -> None:
22 | """Logs config file lineage."""
23 | del filepath
24 | # Internal lineage logging implementation
25 |
--------------------------------------------------------------------------------
/paxml/pip_package/postsubmit.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG image_name
2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest"
3 | FROM $base_image
4 |
5 | RUN rm -rf /praxis && rm -rf /paxml/paxml && rm -rf /paxml/praxis
6 | COPY . /paxml_new
7 | RUN git clone https://github.com/google/praxis.git
8 | RUN mv /praxis/praxis /paxml/ && mv /paxml_new/paxml /paxml/
9 | RUN pip3 uninstall -y fiddle
10 | RUN pip3 uninstall -y seqio
11 | RUN pip3 uninstall -y flax
12 | RUN pip3 uninstall -y jax
13 | RUN pip3 install --no-deps -r /paxml/paxml/pip_package/requirements.txt
14 |
15 | RUN cd /paxml && bazel build ...
16 | RUN cd /paxml && \
17 | bazel test \
18 | --test_output=all \
19 | --test_verbose_timeout_warnings \
20 | -- \
21 | paxml/... \
22 | -paxml/tasks/lm/params:c4_test \
23 | -paxml/tasks/vision:input_generator_test \
24 | -paxml:checkpoint_managers_test \
25 | -paxml:seqio_input_test \
26 | -paxml:tasks_lib_test
27 |
28 | WORKDIR /
29 |
30 | CMD ["/bin/bash"]
31 |
--------------------------------------------------------------------------------
/paxml/pip_package/compile_requirements_helper.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | #!/bin/bash
17 |
18 | set -e -x
19 |
20 | pip3 install -U pip-tools
21 | cd /tmp/requirements
22 | pip-compile --quiet --output-file paxml-requirements.txt praxis-requirements.in paxml-requirements.in
23 | pip-compile --quiet --output-file praxis-requirements.txt praxis-requirements.in
24 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/graphviz_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for graphviz."""
17 |
18 | from absl.testing import absltest
19 | from paxml.tools.fiddle import graphviz_utils
20 | from paxml.tools.fiddle import test_fixtures
21 |
22 |
23 | class GraphvizTest(absltest.TestCase):
24 |
25 | def test_smoke_render(self):
26 | config = test_fixtures.SampleExperimentNewBaseline().experiment_fixture()
27 | graphviz_utils.render(config=config)
28 |
29 |
30 | if __name__ == "__main__":
31 | absltest.main()
32 |
--------------------------------------------------------------------------------
/paxml/tools/validate_config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""A simple utility to validate an experiment config.
17 |
18 | The binary target `:validate_config` is defined by `pax_targets()` in the `BUILD`
19 | file.
20 |
21 | Example commandline:
22 | bazel run //PATH/TO/PAX/TARGETS:validate_config -- \
23 | --exp=lm1b.Lm1bTransformerL32H8kSPMD8x8Repeat \
24 | --completeness=light
25 | """
26 |
27 | from paxml.tools import validate_config_lib
28 |
29 |
30 | if __name__ == '__main__':
31 | validate_config_lib.main()
32 |
33 |
--------------------------------------------------------------------------------
/paxml/pip_package/cloudbuild.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}',
6 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}',
7 | '-f', 'paxml/pip_package/Dockerfile', '.'
8 | ]
9 | timeout: 3600s
10 | - name: 'gcr.io/cloud-builders/docker'
11 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}']
12 | timeout: 1800s
13 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'
14 | entrypoint: 'bash'
15 | args: ['-c', 'source paxml/pip_package/collect_wheels.sh && collect_wheels ${_RELEASE_VERSION} ${_WHEEL_FOLDER}']
16 |
17 | substitutions:
18 | _PYTHON_VERSION: '3.10'
19 | _RELEASE_VERSION: 'nightly' # or rX.Y
20 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
21 | _WHEEL_FOLDER: '/tmp/wheels'
22 | options:
23 | dynamic_substitutions: true
24 | substitution_option: 'ALLOW_LOOSE'
25 | machineType: E2_HIGHCPU_32
26 | timeout: 5400s
27 | artifacts:
28 | objects:
29 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)'
30 | paths: ['/**/*.whl', '/**/*.txt']
31 |
--------------------------------------------------------------------------------
/paxml/pip_package/presubmit.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG image_name
2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest"
3 | FROM $base_image
4 |
5 | RUN rm -rf /praxis && rm -rf /paxml/paxml && rm -rf /paxml/praxis
6 | COPY . /paxml_new
7 | RUN git clone https://github.com/google/praxis.git
8 | RUN mv /praxis/praxis /paxml/ && mv /paxml_new/paxml /paxml/
9 | RUN pip3 uninstall -y fiddle
10 | RUN pip3 uninstall -y seqio
11 | RUN pip3 uninstall -y flax
12 | RUN pip3 uninstall -y jax
13 | RUN pip3 install --no-deps -r /paxml/paxml/pip_package/requirements.txt
14 | RUN cd /paxml && bazel build ...
15 |
16 | # RUN cd /paxml && bazel test paxml/... --test_output=all --test_verbose_timeout_warnings
17 | # RUN cd /paxml && bazel test paxml:tasks_lib_test --test_output=all --test_verbose_timeout_warnings
18 | RUN cd /paxml && \
19 | bazel test \
20 | --test_output=all \
21 | --test_verbose_timeout_warnings \
22 | -- \
23 | paxml/... \
24 | -paxml/tasks/lm/params:c4_test \
25 | -paxml/tasks/vision:input_generator_test \
26 | -paxml:checkpoint_managers_test \
27 | -paxml:seqio_input_test \
28 | -paxml:tasks_lib_test
29 | WORKDIR /
30 |
31 | CMD ["/bin/bash"]
32 |
--------------------------------------------------------------------------------
/paxml/experimental/nested_map_config_helper_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for nested_map_config_helper."""
17 |
18 | from absl.testing import absltest
19 | from paxml.experimental import nested_map_config_helper
20 |
21 |
22 | class NestedMapConfigHelperTest(absltest.TestCase):
23 |
24 | def test_make_nested_map(self):
25 | result = nested_map_config_helper.make_nested_map(base={"foo": 1, "bar": 2})
26 | self.assertEqual(result.foo, 1)
27 | self.assertEqual(result.bar, 2)
28 |
29 |
30 | if __name__ == "__main__":
31 | absltest.main()
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/paxml/pip_package/cloudbuild-release.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}',
6 | '-f', 'paxml/pip_package/release.Dockerfile', '.',
7 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}',
8 | '--build-arg', 'praxis_version=${_PRAXIS_VERSION}',
9 | ]
10 | timeout: 3600s
11 | - name: 'gcr.io/cloud-builders/docker'
12 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}']
13 | timeout: 1800s
14 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'
15 | entrypoint: 'bash'
16 | args: ['-c', 'mv ${_WHEEL_FOLDER}/*.whl .']
17 |
18 | substitutions:
19 | _PYTHON_VERSION: '3.10'
20 | _RELEASE_VERSION: '1.4.0' # or rX.Y
21 | _PRAXIS_VERSION: '1.4.0' # or rX.Y
22 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
23 | _WHEEL_FOLDER: '/tmp/wheels'
24 | options:
25 | dynamic_substitutions: true
26 | substitution_option: 'ALLOW_LOOSE'
27 | machineType: E2_HIGHCPU_8
28 | timeout: 5400s
29 | artifacts:
30 | objects:
31 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)-paxml-${_RELEASE_VERSION}'
32 | paths: ['/**/*.whl']
33 |
--------------------------------------------------------------------------------
/paxml/tasks/test/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Test model configurations.
18 |
19 | load("@rules_python//python:defs.bzl", "py_library")
20 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | licenses(["notice"])
25 |
26 | py_library(
27 | name = "synthetic",
28 | srcs = [
29 | "synthetic.py",
30 | ],
31 | tags = ["keep_dep"],
32 | deps = [
33 | "//paxml:base_experiment",
34 | "//paxml:experiment_registry",
35 | "//praxis:pax_fiddle",
36 | "//praxis/layers",
37 | ],
38 | )
39 |
--------------------------------------------------------------------------------
/paxml/base_task.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base class for all tasks.
17 |
18 | A model solely consists of the network, while a task combines one or several
19 | models with one or several learners/optimizers.
20 | """
21 |
22 | from __future__ import annotations
23 |
24 | import abc
25 |
26 | from praxis import base_hyperparams
27 |
28 |
29 | class BaseTask(
30 | base_hyperparams.FiddleBaseParameterizable, metaclass=abc.ABCMeta
31 | ):
32 | """Abstract base class for all tasks."""
33 |
34 | def __post_init__(self):
35 | assert self.name, (
36 | 'Task params for %s must have a "name"' % self.__class__.__name__
37 | )
38 |
--------------------------------------------------------------------------------
/paxml/tasks/test/synthetic.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test model configuration using synthetic data."""
17 |
18 | from paxml import base_experiment
19 | from paxml import experiment_registry
20 | from praxis import layers
21 | from praxis import pax_fiddle
22 |
23 |
24 | @experiment_registry.register
25 | class SyntheticClassifier(base_experiment.BaseExperiment):
26 | # TODO(shafey): Implement a real test model.
27 |
28 | def datasets(self):
29 | return []
30 |
31 | def task(self):
32 | act_p = pax_fiddle.Config(layers.Identity)
33 | return act_p
34 |
35 |
36 | @experiment_registry.register
37 | class SharedNameExperiment(SyntheticClassifier):
38 | pass
39 |
--------------------------------------------------------------------------------
/paxml/pip_package/collect_wheels.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
#!/bin/bash
# NOTE(review): this shebang is not the first line of the file (the license
# header precedes it), so the kernel ignores it; invoke via `bash` explicitly.
# Copies built wheels into the current directory, renaming paxml/praxis
# wheels so their version embeds the release version plus a UTC date stamp.
#   $1: release version ("nightly" or a semantic version string)
#   $2: folder containing the built .whl and .txt files
function collect_wheels() {
  release_version=$1
  wheel_version="${release_version}"
  wheel_folder=$2
  if [ "${release_version}" != "nightly" ]; then
    # Escape the dots so they match literally (the original '\d+.\d+' let
    # '.' match any character, e.g. "1x2").
    wheel_version=$( echo "${release_version}" | grep -oP '\d+\.\d+(\.\d+)?' )
  fi

  # -p: do not fail if the staging directory already exists (e.g. on rerun).
  mkdir -p /tmp/staging-wheels
  pushd /tmp/staging-wheels
  cp "${wheel_folder}"/*.whl .
  rename -v "s/^paxml-(.*?)-py3/paxml-${wheel_version}+$(date -u +%Y%m%d)-py3/" *.whl
  rename -v "s/^praxis-(.*?)-py3/praxis-${wheel_version}+$(date -u +%Y%m%d)-py3/" *.whl
  popd
  mv /tmp/staging-wheels/* .
  mv "${wheel_folder}"/*.txt .
}
35 |
--------------------------------------------------------------------------------
/paxml/paxml.bzl:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implements custom rules for Paxml."""
17 |
18 | # Placeholder to use until bazel supports pytype_*.
def pytype_library(name, **kwargs):
    # Forwards directly to py_library; no pytype checking is performed.
    native.py_library(name = name, **kwargs)

def pytype_strict_library(name, **kwargs):
    # Same forwarding; "strict" has no effect in this open-source build.
    native.py_library(name = name, **kwargs)

def pytype_binary(name, **kwargs):
    native.py_binary(name = name, **kwargs)

def pytype_strict_binary(name, **kwargs):
    native.py_binary(name = name, **kwargs)

def pytype_strict_test(name, **kwargs):
    native.py_test(name = name, **kwargs)
33 |
34 | # Placeholder to use until bazel supports py_strict_test.
def py_strict_test(name, **kwargs):
    # Forwards to py_test; strict-dep checking is unavailable here.
    native.py_test(name = name, **kwargs)
37 |
--------------------------------------------------------------------------------
/paxml/experimental/nested_map_config_helper.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple shims for Fiddle configs to reduce custom objects.
17 |
18 | Inserting custom objects (anything that's not a Fiddle Buildable, list, tuple,
19 | dict, enum, or primitive) into configuration can limit use of configs with
20 | Fiddle's various tooling.
21 |
22 | Therefore, we have a very simple helper function to convert primitive dicts
23 | (which can live in the config) to NestedMap's (which we want after building the
24 | config).
25 | """
26 |
27 | from typing import Any
28 |
29 | from praxis import pytypes
30 |
31 |
def make_nested_map(base: dict[str, Any]) -> pytypes.NestedMap:
  """Builds a NestedMap from a plain dict.

  Args:
    base: Source dict whose entries populate the NestedMap.

  Returns:
    A NestedMap holding the same items as `base`.
  """
  result = pytypes.NestedMap(base)
  return result
42 |
--------------------------------------------------------------------------------
/paxml/tools/dump_hparams.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""A simple utility to dump experiment hparams to stdout or a txt file.
17 |
18 | The binary target `:dump_hparams` is defined by `pax_targets()` in the `BUILD`
19 | file.
20 |
21 | Example commandline:
22 | python paxml/tools/dump_hparams.py \
23 | --exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamTest \
24 | --params_ofile=/tmp/bert.txt
25 |
26 | Alternatively, omitting --params_ofile just prints to stdout, which might be
27 | useful in shell pipelines.
28 |
29 | You may also additionally specify --post_init_params_ofile, to write a second
30 | file consisting of the post-init model params (this takes longer to generate):
31 | --post_init_params_ofile=/tmp/lm_post.txt
32 | """
33 |
34 | from paxml.tools import dump_hparams_lib
35 |
36 |
if __name__ == '__main__':
  # Flag parsing and all output handling live in dump_hparams_lib.
  dump_hparams_lib.main()
39 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/wrap_nested_maps.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Wraps NestedMap instances into config sub-graphs that produce them.
17 |
18 | Custom objects like NestedMap can cause problems, so we rewrite them into a
19 | fdl.Config of a helper function.
20 | """
21 |
22 | from fiddle import daglish
23 | from paxml.experimental import nested_map_config_helper
24 | from praxis import pax_fiddle
25 | from praxis import pytypes
26 |
27 |
def wrap_nested_maps(config):
  """Rewrites NestedMap values in `config` into configs that build them.

  Every NestedMap found while traversing `config` is replaced by a
  pax_fiddle.Config of `make_nested_map` over a plain dict copy, so the
  resulting config contains only primitives and Buildables.

  Args:
    config: A Fiddle config, or nested structure of configs.

  Returns:
    A transformed copy of `config` with NestedMaps wrapped.
  """

  def _rewrite(node, state: daglish.State):
    if not isinstance(node, pytypes.NestedMap):
      return state.map_children(node)
    wrapped = pax_fiddle.Config(
        nested_map_config_helper.make_nested_map, base=dict(node)
    )
    return state.map_children(wrapped)

  return daglish.MemoizedTraversal.run(_rewrite, config)
39 |
--------------------------------------------------------------------------------
/paxml/ml_monitoring.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ML Monitoring for PAX."""
17 |
18 | import contextlib
19 | import enum
20 |
21 |
class MlEvent(enum.Enum):
  """ML events to be recorded.

  Members name the phases of a PAX run whose boundaries can be reported via
  `record_event_boundary` or bracketed with `ml_event_logger`.
  """

  # NOTE: member order fixes the enum.auto() values (1..6); do not reorder.
  INITIALIZE_BACKEND = enum.auto()
  INITIALIZE_SETUP = enum.auto()
  MAIN_LOOP = enum.auto()
  TRAIN_STEP = enum.auto()
  EVAL_STEP = enum.auto()
  DECODE_STEP = enum.auto()
31 |
32 |
class EventBoundary(enum.Enum):
  """Event boundary to be recorded.

  Marks whether a `record_event_boundary` call denotes the start or the end
  of an `MlEvent`.
  """

  START = enum.auto()
  END = enum.auto()
38 |
39 |
def record_step_number(step_number: int):
  """Records the given training step number; currently a no-op."""
  del step_number  # Intentionally unused in this implementation.
43 |
44 |
def record_event_boundary(event: MlEvent, boundary: EventBoundary, **kwargs):
  """Records the start/end boundary of `event`; currently a no-op."""
  del event, boundary, kwargs  # Intentionally unused in this implementation.
48 |
49 |
@contextlib.contextmanager
def ml_event_logger(event: MlEvent, **kwargs):
  """Context manager bracketing the given event; currently records nothing."""
  del event, kwargs  # Intentionally unused in this implementation.
  yield
53 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/convert_seqio_task_objects_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for convert_seqio_task_objects."""
17 |
18 | from absl.testing import absltest
19 | import fiddle as fdl
20 | from paxml.tools.fiddle import convert_seqio_task_objects
21 | from praxis import pax_fiddle
22 | import seqio
23 | from seqio import test_utils
24 |
25 |
class ConvertSeqioTaskObjectsTest(test_utils.FakeTaskTest):
  """Checks that seqio Task objects are rewritten into name-based configs."""

  def test_convert_seqio_tasks(self):
    original = {"task": seqio.get_mixture_or_task("tfds_task")}
    converted = convert_seqio_task_objects.convert_seqio_tasks(config=original)
    task_config = converted["task"]
    self.assertIsInstance(task_config, pax_fiddle.Config)
    arguments = fdl.ordered_arguments(task_config)  # pytype: disable=wrong-arg-types
    self.assertEqual(arguments, {"task_or_mixture_name": "tfds_task"})
36 |
37 |
if __name__ == "__main__":
  # Runs the absltest runner over this module's test cases.
  absltest.main()
40 |
--------------------------------------------------------------------------------
/paxml/docs/README.md:
--------------------------------------------------------------------------------
1 | # Pax Basics (Start here)
2 |
The **Pax Basics (Start here)** section is designed to provide all of the basic
concepts necessary to understand the rest of the site.
5 |
6 | | **Section** | **Description** |
7 | | -------------------------------- | ----------------------------------------- |
| [About Pax][about-pax]           | A gentle introduction to what Pax is, why it's important, and its major components. |
9 | | [Learning Pax][learning-pax] | An introduction to the Pax Components along with resources for learning more. |
10 | | [Key Concepts][concepts] | A set of concepts Pax users will frequently encounter. |
11 | | [Tutorials][tutorials] | Notebook examples for hands-on introduction. |
12 | | [Life of a Pax Experiment][life] (coming soon)| The main concepts involved in defining and running machine learning experiments using Pax. |
13 |
14 | Armed with the Pax concepts presented in these pages, you should be able to
15 | grasp all of the user journeys and how-to's presented in this site.
16 |
17 |
18 |
19 |
20 | [about-pax]: https://github.com/google/paxml/tree/main/paxml/docs/about-pax.md
21 | [life]: https://github.com/google/paxml/tree/main/paxml/docs/life_of_an_experiment.md
22 | [concepts]: https://github.com/google/paxml/tree/main/paxml/docs/concepts.md
23 | [learning-pax]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md
24 | [tutorials]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import ABC
17 |
18 | from paxml import tasks_lib
19 | from praxis import pax_fiddle
20 |
21 |
class CheckpointRestoreMixin(ABC):
  """Mixin that wires checkpoint-restore rules into a task config.

  Subclasses set CHECKPOINT_RESTORE_PATH (and optionally
  CHECKPOINT_IGNORE_RULES) to have all task variables initialized from that
  checkpoint via `configure_checkpoint_restore`.
  """

  # Checkpoint to initialize from; None leaves the task untouched.
  CHECKPOINT_RESTORE_PATH = None
  # Optional ignore rules forwarded to CheckpointLoadingRules.
  CHECKPOINT_IGNORE_RULES = None

  def configure_checkpoint_restore(
      self, task_p: pax_fiddle.Config[tasks_lib.SingleTask]
  ) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Installs init_from_checkpoint_rules on `task_p` when a path is set."""
    if not self.CHECKPOINT_RESTORE_PATH:
      return task_p

    # Load every variable verbatim ("(.*)" -> "{}") from the checkpoint.
    loading_rules = tasks_lib.CheckpointLoadingRules(
        task_p=task_p.clone(),
        load_rules=[("(.*)", "{}")],
        ignore_rules=self.CHECKPOINT_IGNORE_RULES,
        input_specs_provider_p=self.get_input_specs_provider_params(),
    )
    task_p.train.init_from_checkpoint_rules = {
        self.CHECKPOINT_RESTORE_PATH: loading_rules
    }
    return task_p
42 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/codegen_pax_code_ir.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A few small code IR nodes for Pax."""
17 |
18 | import dataclasses
19 | from typing import Any
20 |
21 | from fiddle._src.codegen.auto_config import code_ir
22 | import libcst as cst
23 |
24 |
@dataclasses.dataclass
class PaxCodegenTask(code_ir.CodegenTask):
  """CodegenTask that tracks a few extra bits of state.

  Attributes:
    highlevel_accesses: Accesses to high-level settings. These become fields on
      the generated class.
    sharding_diff_module: CST module containing a fiddler that will re-add
      sharding to a model. Factoring our fixtures into ones that generate an
      unsharded module, and a function that re-adds the sharding can be more
      readable. (You can disable this in `codegen.py`.)
  """

  # Presumably keyed by access name/path — confirm against the codegen pass.
  highlevel_accesses: dict[str, Any] = dataclasses.field(default_factory=dict)
  # None when sharding-diff generation is disabled.
  sharding_diff_module: cst.Module | None = None
40 |
--------------------------------------------------------------------------------
/paxml/tasks/vision/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Vision modeling-specific libraries and model configurations
18 |
19 | load("//paxml:paxml.bzl", "pytype_strict_library")
20 | load("//paxml:paxml.bzl", "py_strict_test")
21 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
22 |
23 | licenses(["notice"])
24 |
25 | package(default_visibility = JAX_VISIBILITY)
26 |
# Input pipeline and ResNet preprocessing for the vision task configs.
pytype_strict_library(
    name = "input_generator",
    srcs = [
        "input_generator.py",
        "resnet_preprocessing.py",
    ],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit tensorflow_no_contrib dependency.
    ],
)

# Tagged notap/external: presumably needs network access for dataset files
# (see "requires-net:external"), so it cannot run in hermetic CI.
py_strict_test(
    name = "input_generator_test",
    srcs = ["input_generator_test.py"],
    tags = [
        "external",
        "notap",
        "requires-net:external",
    ],
    deps = [
        ":input_generator",
        # Implicit absl.testing.absltest.absltest dependency.
    ],
)
52 |
--------------------------------------------------------------------------------
/paxml/checkpoint_types.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Module defining possible checkpoint types and utility methods."""
17 |
18 | import enum
19 | from etils import epath
20 | from paxml import base_task
21 | from praxis import pax_fiddle
22 | from praxis import py_utils
23 |
24 |
@enum.unique
class CheckpointType(str, enum.Enum):
  """The type of the checkpointing format.

  Subclasses str so members compare equal to, and serialize as, their
  lowercase string values.
  """

  UNSPECIFIED = 'unspecified'
  FLAX = 'flax'
  GDA = 'gda'
  PERSISTENCE = 'persistence'
33 |
34 |
def retrieve_checkpoint_type(
    maybe_use_persistence_checkpointing,
    task: base_task.BaseTask | pax_fiddle.Config[base_task.BaseTask],
) -> CheckpointType:
  """Retrieves the CheckpointType given the input arguments."""
  uses_pjit = task.model.mesh_shape is not None  # pytype: disable=attribute-error
  if not uses_pjit and not py_utils.pmap_use_tensorstore():
    # pmap uses FLAX, Persistence-based or not.
    return CheckpointType.FLAX
  if maybe_use_persistence_checkpointing:
    return CheckpointType.PERSISTENCE
  return CheckpointType.GDA
49 |
--------------------------------------------------------------------------------
/paxml/pip_package/compile_requirements.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
#!/bin/bash
# NOTE(review): this shebang is preceded by the license header, so it is
# inert; run the script explicitly with `bash`.
# This script generates new requirements.txt for Praxis and Paxml.
# It pulls the nightly build docker image, and re-compiles requirements.in.

set -e -x

export TMP_FOLDER="$HOME/tmp/requirements"

# Remove any stale staging directory from a previous run. The original used
# `-f` (regular file), which never matched the *directory* created below, so
# stale contents were silently reused.
[ -d "$TMP_FOLDER" ] && rm -rf "$TMP_FOLDER"
mkdir -p "$TMP_FOLDER"
cp ../../paxml/pip_package/requirements.in "$TMP_FOLDER/paxml-requirements.in"
cp ../../praxis/pip_package/requirements.in "$TMP_FOLDER/praxis-requirements.in"
cp ./compile_requirements_helper.sh "$TMP_FOLDER/"
# Comment out the praxis line so paxml's compile does not pin praxis itself.
sed -i 's/praxis/#praxis/' "$TMP_FOLDER/paxml-requirements.in"

docker pull gcr.io/pax-on-cloud-project/paxml_nightly_3.10:latest
docker run --rm -a stdin -a stdout -a stderr -v "$TMP_FOLDER":/tmp/requirements \
  --name container1 gcr.io/pax-on-cloud-project/paxml_nightly_3.10:latest \
  bash /tmp/requirements/compile_requirements_helper.sh

cp "$TMP_FOLDER/paxml-requirements.txt" ../../paxml/pip_package/requirements.txt
cp "$TMP_FOLDER/praxis-requirements.txt" ../../praxis/pip_package/requirements.txt

rm -rf "$TMP_FOLDER"
40 |
--------------------------------------------------------------------------------
/paxml/tools/dump_input_specs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Dumps the input specs of an experiment config."""
17 |
18 | from collections.abc import Sequence
19 |
20 | from absl import app
21 | from absl import flags
22 | from paxml import experiment_registry
23 | from paxml.tools import dump_input_specs_lib
24 | import tensorflow.compat.v2 as tf
25 |
# Flag holders; both flags are marked required in the __main__ guard below.
_EXP = flags.DEFINE_string('exp', None, 'A registered experiment name.')
_OUTPUT_FILENAME = flags.DEFINE_string(
    'output_filename', None, 'Output filename for dumping the input_specs.')

FLAGS = flags.FLAGS
31 |
32 |
def main(argv: Sequence[str]) -> None:
  """Extracts and writes the input specs of the experiment named by --exp.

  Args:
    argv: Command-line arguments; no positional arguments are accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  experiment_config = experiment_registry.get(_EXP.value)()

  specs = dump_input_specs_lib.extract_input_specs(experiment_config)
  # Use the typed flag holder for consistency with the rest of this function
  # (was the untyped FLAGS.exp access).
  out_str = dump_input_specs_lib.specs_to_string(_EXP.value, specs)
  with tf.io.gfile.GFile(_OUTPUT_FILENAME.value, 'w') as fout:
    fout.write(out_str)
  print(out_str)
43 | print(out_str)
44 |
45 |
if __name__ == '__main__':
  # Both flags must be provided; absl fails fast before main() runs.
  flags.mark_flags_as_required(['exp', 'output_filename'])
  app.run(main)
49 |
--------------------------------------------------------------------------------
/paxml/tasks/vision/input_generator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for input_generator."""
17 |
18 | from absl.testing import absltest
19 | from paxml.tasks.vision import input_generator
20 |
21 |
class InputGeneratorTest(absltest.TestCase):
  """Smoke tests for the ImageNet input generators."""

  def _check_data_shape(self,
                        p,
                        is_multlabels=False,
                        batch_size=8,
                        image_size=33):
    """Instantiates `p` and checks the shapes of one preprocessed batch."""
    # NOTE(review): `is_multlabels` is accepted but never used — confirm
    # whether multi-label shape checking was intended here.
    p.data_shape = (image_size, image_size, 3)
    p.batch_size = batch_size
    # Single-threaded, unbuffered pipeline keeps the test deterministic.
    p.num_batcher_threads = 1
    p.file_parallelism = 1
    p.file_buffer_size = 1

    input_gen = p.Instantiate()
    batch = input_gen.GetPreprocessedInputBatch()

    height, width, depth = p.data_shape
    self.assertEqual(batch.image.shape,
                     (p.batch_size, height, width, depth))
    self.assertEqual(batch.label_probs.shape, (p.batch_size, p.num_classes))

  def testImageNetValidation(self):
    self._check_data_shape(input_generator.ImageNetValidation.Params())

  def testImageNetTrain(self):
    self._check_data_shape(input_generator.ImageNetTrain.Params())
47 |
48 |
if __name__ == '__main__':
  # Runs the absltest runner over this module's test cases.
  absltest.main()
51 |
--------------------------------------------------------------------------------
/paxml/experiment_imports_all_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test experiment configurations import and construction."""
17 |
18 | from absl import app
19 | from absl import flags
20 | from absl.testing import absltest
21 | import jax
22 | from paxml import experiment_imports_test_helper
23 | from paxml import experiment_registry
24 |
# Smoke-test filtering flags; both are consumed in main() below.
flags.DEFINE_list(
    'exclude_regexes', [],
    'Exclusion regexes of experiment configurations to be passed to the smoke '
    'test. The matching experiment configurations will be disabled from the '
    'smoke test.')
flags.DEFINE_list(
    'include_only_regexes', [],
    'If provided, only experiments with names matching these regexes will be '
    'tested.')

FLAGS = flags.FLAGS
36 |
37 |
class Test(experiment_imports_test_helper.ExperimentImportsTestHelper):
  # Test methods are attached dynamically in main() below via
  # create_test_methods_for_all_registered_experiments.
  pass
40 |
41 |
def main(args):
  """Registers one smoke test per registered experiment, then runs absltest.

  Args:
    args: Unused positional command-line arguments from app.run.
  """
  del args  # Unused.

  n = Test.create_test_methods_for_all_registered_experiments(
      experiment_registry,
      task_regexes=[''],
      exclude_regexes=FLAGS.exclude_regexes,
      include_only_regexes=FLAGS.include_only_regexes)
  # A bare `assert` is stripped under `python -O`; raise explicitly so the
  # empty-registry failure cannot be silently skipped.
  if n <= 0:
    raise AssertionError('No experiment registered!')

  absltest.main()
53 |
54 |
if __name__ == '__main__':
  # Let absl-parsed flags (e.g. --jax_*) configure JAX before the app runs.
  jax.config.parse_flags_with_absl()
  app.run(main)
58 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/convert_seqio_task_objects.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Converts SeqIO task objects to configs referencing their name."""
17 |
18 | from typing import TypeVar
19 |
20 | from fiddle import daglish
21 | from praxis import pax_fiddle
22 | import seqio
23 |
24 | _T = TypeVar("_T")
25 |
26 |
def convert_seqio_tasks(config: _T) -> _T:
  """Converts SeqIO Task objects within a config to more config-like objects.

  SeqIO tasks are identifiable by their name in a global registry. That is
  not an ideal pattern, but since it is hard to convert a Task object back
  into a config that produces it, referring to the registered name is a
  reasonable strategy for now.

  Args:
    config: Fiddle config, or nested structure of configs.

  Returns:
    Version of config without SeqIO task instances.
  """

  def _replace_task(value, state: daglish.State):
    if not isinstance(value, seqio.Task):
      return state.map_children(value)
    # The resulting config exposes a `task_or_mixture_name` parameter instead
    # of a `name` parameter, slightly changing the config API.
    return pax_fiddle.Config(seqio.get_mixture_or_task, value.name)

  return daglish.MemoizedTraversal.run(_replace_task, config)
51 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/wrap_nested_maps_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for wrap_nested_maps."""
17 |
18 | from absl.testing import absltest
19 | import fiddle.testing
20 | from paxml.experimental import nested_map_config_helper
21 | from paxml.tools.fiddle import wrap_nested_maps
22 | from praxis import pax_fiddle
23 | from praxis import pytypes
24 |
25 |
class WrapNestedMapsTest(fiddle.testing.TestCase):
  """Tests that NestedMaps are rewritten into make_nested_map configs."""

  def test_wrap_nested_maps(self):
    inner = {"value": 1}
    config = [
        pytypes.NestedMap(
            foo={
                "subdict": pytypes.NestedMap(bar=inner),
                "another_value": (inner,),
            }
        )
    ]
    result = wrap_nested_maps.wrap_nested_maps(config=config)

    # Build the expected DAG with a fresh (but equal) shared dict so that
    # the sharing structure mirrors the input.
    expected_inner = {"value": 1}
    expected_subdict = pax_fiddle.Config(
        nested_map_config_helper.make_nested_map,
        base={"bar": expected_inner},
    )
    expected = [
        pax_fiddle.Config(
            nested_map_config_helper.make_nested_map,
            base={
                "foo": {
                    "subdict": expected_subdict,
                    "another_value": (expected_inner,),
                }
            },
        )
    ]
    self.assertDagEqual(result, expected)
55 |
56 |
if __name__ == "__main__":
  # Runs the absltest runner over this module's test cases.
  absltest.main()
59 |
--------------------------------------------------------------------------------
/paxml/experimental/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Experimental packages.
17 |
load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

licenses(["notice"])

# Baseline-experiment helper library (depends on parameterized_experiment
# and pax_fiddle).
pytype_strict_library(
    name = "baseline_experiment",
    srcs = ["baseline_experiment.py"],
    deps = [
        # Implicit fiddle dependency.
        "//paxml:parameterized_experiment",
        "//praxis:pax_fiddle",
    ],
)

# Helper for representing NestedMap values in fiddle configs; only needs
# praxis pytypes.
pytype_strict_library(
    name = "nested_map_config_helper",
    srcs = ["nested_map_config_helper.py"],
    deps = ["//praxis:pytypes"],
)

pytype_strict_test(
    name = "nested_map_config_helper_test",
    srcs = ["nested_map_config_helper_test.py"],
    deps = [
        ":nested_map_config_helper",
        # Implicit absl.testing.absltest.absltest dependency.
    ],
)

pytype_strict_test(
    name = "baseline_experiment_test",
    srcs = ["baseline_experiment_test.py"],
    deps = [
        ":baseline_experiment",
        # Implicit absl.testing.absltest.absltest dependency.
        "//paxml:parameterized_experiment",
        "//paxml:tasks_lib",
        "//praxis:base_model",
        "//praxis:pax_fiddle",
    ],
)
62 |
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Workspace file for PAXML."""
17 |
load("//paxml:build_defs.bzl", "pax_targets") # @unused
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# bazel-skylib 1.2.1, pinned by sha256; fetched from the Bazel mirror first.
http_archive(
    name = "bazel_skylib",
    urls = [
        "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz",
        "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz",
    ],
    sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728",
)

# This load must stay below the bazel_skylib http_archive that defines the
# @bazel_skylib repository (hence the buildifier suppression).
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") #buildifier: disable=load-on-top

bazel_skylib_workspace()

# rules_python 0.8.1, pinned by sha256.
http_archive(
    name = "rules_python",
    sha256 = "cdf6b84084aad8f10bf20b46b77cb48d83c319ebe6458a18e9d2cebf57807cdd",
    strip_prefix = "rules_python-0.8.1",
    url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.8.1.tar.gz",
)

# zlib 1.2.11 source archive; its BUILD file comes from protobuf.
http_archive(
    name = "zlib",
    build_file = "@com_google_protobuf//:third_party/zlib.BUILD",
    sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
    strip_prefix = "zlib-1.2.11",
    urls = [
        "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
        "https://zlib.net/zlib-1.2.11.tar.gz",
    ],
)
51 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Runs the Pile126M config on a single node.
#
# Usage:
#   run_pile_singlenode.sh TFDS_DATA_DIR VOCAB_PATH [PREC] [NUM_GPUS] \
#     [PERCORE_BATCH_SIZE] [LOG_DIR]
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH

# Default XLA flags; callers may override/extend via BASE_XLA_FLAGS/XLA_FLAGS.
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_reduce_combine_threshold_bytes=51200
                --xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Pile126M \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS}, 1, 1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.PERCORE_BATCH_SIZE="$PERCORE_BATCH_SIZE" \
    --tfds_data_dir="$TFDS_DATA_DIR" \
    --alsologtostderr \
    2>&1 | tee "${LOG_DIR}/pile_output.log"

# With `set -o pipefail`, $? reflects the python exit status even though the
# output is piped through tee.
EXP_STATUS=$?

if [ "$EXP_STATUS" -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi

echo "Output written to ${LOG_DIR}/pile_output.log"
57 |
--------------------------------------------------------------------------------
/paxml/partitioning_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for partitioning."""
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | from paxml import partitioning
21 | from praxis import py_utils
22 | from praxis import test_utils
23 |
24 | NestedMap = py_utils.NestedMap
25 | PartitionSpec = jax.sharding.PartitionSpec
26 |
27 |
class PartitioningTest(test_utils.TestCase):
  """Tests for partitioning.filter_nestedmap."""

  @parameterized.parameters([(NestedMap, dict), (dict, NestedMap)])
  def test_filter_nested_map_basics(self, src_type, filter_type):
    # The filter's container type should not matter; the output follows the
    # source container type.
    full = src_type(a=1, b=src_type(c=2, d=[3, src_type(e=6, f=7)]))
    keep = filter_type(a=0, b=filter_type(d=[0, filter_type(e=0)]))

    filtered = partitioning.filter_nestedmap(full, keep)

    self.assertIsInstance(filtered, src_type)
    self.assertEqual(src_type(a=1, b=src_type(d=[3, src_type(e=6)])), filtered)

  def test_filter_nested_map_with_partition_spec(self):
    # PartitionSpec leaves should survive filtering unchanged.
    full = dict(a=[PartitionSpec(None), dict(b=2, c=PartitionSpec(None))])
    keep = dict(a=[0, dict(c=0)])

    filtered = partitioning.filter_nestedmap(full, keep)

    self.assertIsInstance(filtered, dict)
    self.assertEqual(
        dict(a=[PartitionSpec(None), dict(c=PartitionSpec(None))]), filtered
    )
50 |
51 |
52 | if __name__ == '__main__':
53 | absltest.main()
54 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/run_lambada_singlenode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Evaluates a pretrained Lambada126M checkpoint on a single node.
#
# Usage:
#   run_lambada_singlenode.sh TFDS_DATA_DIR VOCAB_PATH [PREC] [NUM_GPUS] \
#     [PERCORE_BATCH_SIZE] LOG_DIR
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
### path to pretrained log_dir (required; no default)
LOG_DIR=$6

export VOCAB_PATH=$VOCAB_PATH
# Default XLA flags; callers may override/extend via BASE_XLA_FLAGS/XLA_FLAGS.
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_reduce_combine_threshold_bytes=51200
                --xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Lambada126M \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS}, 1, 1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.PERCORE_BATCH_SIZE="$PERCORE_BATCH_SIZE" \
    --tfds_data_dir="$TFDS_DATA_DIR" \
    --mode='eval' \
    --alsologtostderr \
    2>&1 | tee "${LOG_DIR}/lambada_output.log"

# With `set -o pipefail`, $? reflects the python exit status even though the
# output is piped through tee.
EXP_STATUS=$?

if [ "$EXP_STATUS" -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi

echo "Output written to ${LOG_DIR}/lambada_output.log"
58 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/run_base_config_multinode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Runs a named base config on a multi-node cluster.
# Assumes you are using a SLURM cluster. Edit flags under --multiprocess_gpu below to suit your setup
set -u

CONFIG=$1
PREC=${2:-"bfloat16"}        # Precision (float32, bfloat16)
# NOTE(review): NUM_GPUS is accepted for CLI compatibility but is not
# referenced below — confirm whether it should feed ICI_MESH_SHAPE.
NUM_GPUS=${3:-8}             # Number of GPUs (1, 2, 4, 8)
LOG_DIR=${4:-"test_logdir"}
TFDS_DATA_DIR=${5:-'None'}
VOCAB_PATH=${6:-'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'}
ADDITIONAL_ARGS=${7:-""}

export VOCAB_PATH=$VOCAB_PATH

export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}
# Default XLA flags; callers may override/extend via BASE_XLA_FLAGS/XLA_FLAGS.
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"\
    --xla_gpu_enable_latency_hiding_scheduler=true \
    --xla_gpu_enable_triton_gemm=false \
    --xla_gpu_enable_highest_priority_async_stream=true \
    --xla_gpu_all_reduce_combine_threshold_bytes=51200 \
    --xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


mkdir -p "$LOG_DIR"
# The $(...) and ${ADDITIONAL_ARGS} expansions below are intentionally
# unquoted: they expand to zero or more whole flags via word splitting.
python3 -u -m paxml.main \
    --job_log_dir="$LOG_DIR" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.${CONFIG} \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --multiprocess_gpu \
    --server_addr=${SLURM_LAUNCH_NODE_IPADDR}:12345 \
    --num_hosts=$SLURM_NTASKS \
    --host_idx=$SLURM_PROCID \
    --alsologtostderr \
    $([[ $TFDS_DATA_DIR != "None" ]] && echo --tfds_data_dir=$TFDS_DATA_DIR) \
    ${ADDITIONAL_ARGS}
52 |
53 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Runs Pile126M pretraining across multiple nodes.
# Assumes you are using a SLURM cluster. Edit flags under --multiprocess_gpu below to suit your setup
set -u

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH
export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.85}
# Default XLA flags; callers may override/extend via BASE_XLA_FLAGS/XLA_FLAGS.
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_reduce_combine_threshold_bytes=51200
                --xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


## NOTE: 126M trained with pure data parallel
mkdir -p "$LOG_DIR"
python3 -u -m paxml.main \
    --job_log_dir="$LOG_DIR" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Pile126M \
    --tfds_data_dir="$TFDS_DATA_DIR" \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS},1,1]" \
    --fdl.DCN_MESH_SHAPE="[${SLURM_JOB_NUM_NODES},1,1]" \
    --fdl.PERCORE_BATCH_SIZE="$PERCORE_BATCH_SIZE" \
    --multiprocess_gpu \
    --server_addr=${SLURM_LAUNCH_NODE_IPADDR}:12345 \
    --num_hosts=$SLURM_NTASKS \
    --host_idx=$SLURM_PROCID \
    --alsologtostderr
51 |
52 |
--------------------------------------------------------------------------------
/paxml/base_executor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Program executor that drives the training/evaluation loops."""
17 |
18 | import abc
19 | from typing import Any, Sequence
20 |
21 | from etils import epath
22 | from paxml import decode_programs as decode_program_lib
23 | from paxml import partitioning
24 | from paxml import programs
25 | from paxml import tasks_lib
26 | from paxml import trainer_lib
27 | from praxis import base_input
28 | from praxis import pax_fiddle
29 |
30 |
class BaseExecutor(metaclass=abc.ABCMeta):
  """Abstract interface for the executor that drives the programs.

  Concrete executors are configured once via `setup()` and then run their
  train/eval/decode programs via `start()`.
  """

  @abc.abstractmethod
  def setup(
      self,
      jax_task: tasks_lib.SingleTask,
      job_log_dir: epath.Path,
      checkpointer: Any,
      partitioner: partitioning.Partitioner,
      input_specs_provider: base_input.BaseInputSpecsProvider,
      # TODO(laigd): encapsulate train_input_p in train_program.
      train_input_p: pax_fiddle.Config[base_input.BaseInput],
      train_program: programs.BaseTrainProgram,
      eval_programs: Sequence[programs.BaseEvalProgram],
      decode_programs: Sequence[decode_program_lib.SingleTaskDecodeProgram],
      # TODO(laigd): this shouldn't be part of the executor API, consider adding
      # a dedicated executor for auto-tuning and get rid of this instead.
      early_stopping_fn: trainer_lib.EarlyStoppingFn | None,
      exit_after_ondemand_checkpoint: bool = False,
      enable_summary_writer: bool = True,
  ) -> None:
    """Sets up the programs and the executor.

    Args:
      jax_task: The task to execute.
      job_log_dir: Path to the job's log directory.
      checkpointer: Checkpointing helper. Deliberately typed `Any` here; see
        concrete executors for the expected interface.
      partitioner: Partitioner used for the computation.
      input_specs_provider: Provides the input specs.
      train_input_p: Config for the training input pipeline.
      train_program: The train program to run.
      eval_programs: The eval programs to run.
      decode_programs: The decode programs to run.
      early_stopping_fn: Optional early-stopping hook, or None to disable.
      exit_after_ondemand_checkpoint: Whether to exit after an on-demand
        checkpoint is taken.
      enable_summary_writer: Whether summary writing is enabled.
    """

  @abc.abstractmethod
  def start(self) -> None:
    """Start executing the programs."""
57 |
--------------------------------------------------------------------------------
/paxml/experiment_vars_summary_parser.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Util for parsing experiment summary text files.
17 |
18 | This util is intentionally in a separate lightweight module that can be used in
19 | other libraries without adding heavy dependencies on other core paxml modules.
20 | """
21 |
22 | import ast
23 | from typing import Any
24 |
25 | from absl import logging
26 |
27 |
def parse(cls_vars_summary: str) -> dict[str, Any]:
  """Parses a class variables summary into a dictionary of vars to values.

  Parses summaries created by experiment_utils.get_cls_vars_summary(). Values
  are left as strings if ast.literal_eval fails. For example, class instances
  will be strings.

  Args:
    cls_vars_summary: A summary in the format created by
      experiment_utils.get_cls_vars_summary().

  Returns:
    Dictionary of variable names and values.
  """
  parsed: dict[str, Any] = {}
  for raw_line in cls_vars_summary.splitlines():
    stripped = raw_line.strip()
    # Class-name headers end with ':'; blank lines carry no data.
    if not stripped or stripped.endswith(':'):
      continue
    name, sep, value = stripped.partition(':')
    if not sep:
      # No "name: value" separator found on this line.
      logging.warning('Warning: Ignoring line with unexpected format: %s',
                      raw_line)
      continue
    name, value = name.strip(), value.strip()
    try:
      parsed[name] = ast.literal_eval(value)
    except (ValueError, SyntaxError):
      # If unable to evaluate as literal, then store value as string.
      parsed[name] = value
  return parsed
60 |
--------------------------------------------------------------------------------
/paxml/checkpoint_version.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Stores current checkpoint version and version history."""
17 |
18 | #
19 | # Past versions:
20 | # 1.2
21 | # - State checkpoint uses Tensorstore's OCDBT format.
22 | # - The state will consist of Tensorstore-managed files plus a 'checkpoint' file
23 | # managed by Orbax, which stores the PyTree structure.
24 | #
25 | # 1.1
26 | # - Metadata has a new key 'train_state_metadata', which is a pytree of array
27 | # metadata corresponding to the train state, including shape, dtype and
28 | # is_masked_node for `TrainState.mdl_vars`.
29 | #
30 | # 1.0
31 | # - Checkpoints folders are organized into per-step directories, where each has
32 | # a subdirectory for every item.
33 | # - The items are 'state' and 'metadata'.
34 | # - Per-step metadata contains a version key.
35 | #
36 | # 0.0
37 | # - Checkpoints do not have per-item directories.
38 | # - Flax checkpoints may or may not be contained within a step directory. In
39 | # other words, the msgpack file may be 'checkpoint_1' instead of
40 | # 'checkpoint_1/checkpoint', where 'checkpoint' is the msgpack file.
41 |
42 | # TODO(b/273803615) When rolled out globally, make _OCDBT_VERSION the standard
43 | # version.
44 | _OCDBT_VERSION: float = 1.2
45 | _VERSION: float = 1.1
46 | _VERSION_KEY: str = 'version'
47 |
48 |
49 | def get_version(tensorstore_use_ocdbt: bool | None = None) -> float:
50 | if tensorstore_use_ocdbt is None:
51 | raise ValueError('Must set the value of `tensorstore_use_ocdbt`.')
52 | if tensorstore_use_ocdbt:
53 | return _OCDBT_VERSION
54 | return _VERSION
55 |
56 |
57 | def get_version_key() -> str:
58 | return _VERSION_KEY
59 |
--------------------------------------------------------------------------------
/paxml/pip_package/build_pip_pkg.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Builds the paxml pip wheel into the destination directory given as $1.

set -e

# NOTE(review): PLATFORM and PIP_FILE_PREFIX are assigned but not referenced
# below — confirm whether they are still needed.
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"

PIP_FILE_PREFIX="pip_package/"

export PYTHON_VERSION="${PYTHON_VERSION:-3}"
export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION:-}"

# Pick the python interpreter, e.g. python3 or python3.10.
if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then
  PYTHON="python${PYTHON_VERSION}"
else
  PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}"
fi

function main() {
  DEST=${1}
  if [[ -z "${DEST}" ]]; then
    echo "No destination directory provided."
    exit 1
  fi

  # Create the directory, then resolve it to an absolute path so tilde and
  # relative components are normalized. (`mkdir -p` is a no-op when the
  # directory already exists.)
  mkdir -p "${DEST}"

  DEST=$(readlink -f "${DEST}")
  echo "=== destination directory: ${DEST}"

  TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)

  echo $(date) : "=== Using tmpdir: ${TMPDIR}"

  echo "=== Copy paxml files"

  cp setup.py "${TMPDIR}"
  cp requirements.in "${TMPDIR}"
  cp LICENSE "${TMPDIR}"
  rsync -avm -L paxml "${TMPDIR}"
  # Also collect generated artifacts (.so, *_pb2.py) from bazel-bin, skipping
  # runfiles and object directories.
  rsync -avm -L --include="*.so" --include="*_pb2.py" \
    --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \
    bazel-bin/paxml "${TMPDIR}"

  pushd "${TMPDIR}"
  echo $(date) : "=== Building wheel"

  ${PYTHON} setup.py bdist_wheel
  cp dist/*.whl "${DEST}"
  popd
  rm -rf "${TMPDIR}"
  echo $(date) : "=== Output wheel file is in: ${DEST}"
}

main "$@"
73 |
--------------------------------------------------------------------------------
/paxml/main_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from paxml import base_experiment
17 | from paxml import experiment_registry
18 | from paxml import main
19 | from absl.testing import absltest
20 |
21 |
class FakeExperimentClassForTest(base_experiment.BaseExperiment):
  """Minimal experiment class used to exercise registry lookup below."""
  pass
24 |
25 |
class MainTest(absltest.TestCase):
  """Tests for main.get_experiment lookup and its error messages."""

  def test_get_experiment_failure_to_import_module(self):
    # fake_module_for_paxml_main_test is not a module that exists in this
    # test, so the import itself must fail.
    expected_regex = (
        'Could not find experiment'
        ' `fake_module_for_paxml_main_test.Experiment9876` because could not'
        ' import module `fake_module_for_paxml_main_test`'
    )
    with self.assertRaisesRegex(ValueError, expected_regex):
      main.get_experiment('fake_module_for_paxml_main_test.Experiment9876')

  def test_get_experiment_failure_to_find_experiment_in_module(self):
    # Module builtins is guaranteed to exist, but there's no corresponding
    # experiment in the builtins module.
    expected_regex = (
        'Could not find experiment `builtins.Experiment9876`.\n'
        'Registered experiments are: {}'
    )
    with self.assertRaisesRegex(ValueError, expected_regex):
      main.get_experiment('builtins.Experiment9876')

  def test_get_experiment_success(self):
    try:
      experiment_registry.register(FakeExperimentClassForTest)

      name = (
          f'{FakeExperimentClassForTest.__module__}'
          f'.{FakeExperimentClassForTest.__qualname__}'
      )
      self.assertEqual(main.get_experiment(name), FakeExperimentClassForTest)

    finally:
      # Reset registry to empty.
      experiment_registry._ExperimentRegistryHelper._registry = {}
64 |
65 |
66 | if __name__ == '__main__':
67 | absltest.main()
68 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup.py file for paxml."""
17 |
18 | import os
19 | from setuptools import find_namespace_packages
20 | from setuptools import setup
21 |
# Set this envvar (to any non-empty value) to avoid installing packages from
# head ("pkg @ <url>" lines) that can overwrite existing installs of those
# packages, e.g., jax.
SKIP_HEAD_INSTALLS = os.environ.get('SKIP_HEAD_INSTALLS', '')
25 |
def _get_requirements():
  """Parses the requirements.in file next to this setup.py.

  Blank lines and "#" comment lines are skipped. When the SKIP_HEAD_INSTALLS
  environment variable is set, direct-reference lines ("pkg @ <url>", i.e.
  installs from head) are skipped as well so they do not overwrite existing
  installs such as jax.

  Returns:
    List of requirement strings for setuptools' install_requires.
  """
  install_requires_tmp = []
  requirements_path = os.path.join(
      os.path.dirname(__file__), 'requirements.in'
  )
  with open(requirements_path, 'r', encoding='utf-8') as f:
    for line in f:
      package_name = line.strip()
      # Skip empty lines and comments starting with "#".
      if not package_name or package_name.startswith('#'):
        continue
      # Skip direct-reference (head) installs when requested via env var.
      if ' @ ' in package_name and SKIP_HEAD_INSTALLS:
        continue
      install_requires_tmp.append(package_name)
  return install_requires_tmp
44 |
45 |
# Requirements are read from requirements.in at build time.
install_requires = _get_requirements()

setup(
    name='paxml',
    version='1.4.0',  # use major/minor version number, e.g. "0.1.0"
    description=(
        'Framework to configure and run machine learning experiments '
        'on top of Jax.'
    ),
    author='PAX team',
    author_email='pax-dev@google.com',
    # Ship every package under the paxml namespace.
    packages=find_namespace_packages(include=['paxml*']),
    python_requires='>=3.10',
    install_requires=install_requires,
    url='https://github.com/google/paxml',
    license='Apache-2.0',
    # Extra deps needed only by the GPU data-download scripts.
    extras_require={
        'gpu': ['jsonlines==3.1.0', 'pysimdjson==5.0.2', 'zstandard==0.18.0'],
    },
    classifiers=[
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
    ],
    zip_safe=False,
)
71 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/codegen_highlevel_parameterization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for codegen_highlevel_parameterization."""
17 |
18 | from absl.testing import absltest
19 | from fiddle._src.codegen.auto_config import ir_printer
20 | from paxml.tools.fiddle import codegen
21 | from paxml.tools.fiddle import codegen_highlevel_parameterization
22 | from paxml.tools.fiddle import codegen_tracer
23 |
24 |
class CodegenTest(absltest.TestCase):
  """Tests for the HighlevelParameterization codegen pass."""

  def _run_highlevel_pass(self, value):
    # Builds a codegen task from `value`, applies the pass, and checks the
    # pass mutates (and returns) the same task object.
    task = codegen.InitTask()(value)
    highlevel_pass = (
        codegen_highlevel_parameterization.HighlevelParameterization(
            lowercasing=False
        )
    )
    self.assertIs(highlevel_pass(task), task)
    return task

  def test_highlevel_parameterization_transform(self):
    tracer = codegen_tracer.make_tracer("tracer_name", 1)
    task = self._run_highlevel_pass(tracer)
    self.assertEqual(
        ir_printer.format_expr(task.top_level_call.fn.output_value),
        "self.tracer_name",
    )

  def test_highlevel_parameterization_transforms_keys(self):
    foo_tracer = codegen_tracer.make_tracer("tracer_foo", 1)
    bar_tracer = codegen_tracer.make_tracer("tracer_bar", 2)
    task = self._run_highlevel_pass(
        {foo_tracer: [1, 2, 3], 0: 10, bar_tracer: [4, 5, 6]}
    )
    converted = task.top_level_call.fn.output_value

    with self.subTest("order_preservation"):
      self.assertEqual(list(converted.values()), [[1, 2, 3], 10, [4, 5, 6]])

    with self.subTest("tracer_conversion"):
      keys = list(converted.keys())
      self.assertEqual(keys[0].attribute, "tracer_foo")
      self.assertEqual(keys[2].attribute, "tracer_bar")
54 | with self.subTest("tracer_conversion"):
55 | self.assertEqual(list(converted.keys())[0].attribute, "tracer_foo")
56 | self.assertEqual(list(converted.keys())[2].attribute, "tracer_bar")
57 |
58 |
59 | if __name__ == "__main__":
60 | absltest.main()
61 |
--------------------------------------------------------------------------------
/paxml/experiment_vars_summary_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unit tests for creating and parsing experiment class variables summaries."""
17 |
18 | import unittest
19 |
20 | from paxml import base_experiment
21 | from paxml import experiment_utils
22 | from paxml import experiment_vars_summary_parser
23 |
24 |
class TestExperimentA(base_experiment.BaseExperiment):
  """Test fixture: base experiment with one class var of each basic type."""

  # NOTE: the sibling summary test asserts on these vars' names and values;
  # declaration order here is reflected in the rendered summary.
  INT_VAR = 0
  STR_VAR = 'A'
  TUPLE_VAR = (0, 'A')
29 |
30 |
class TestExperimentB(TestExperimentA):
  """Test fixture: overrides STR_VAR and adds a bool var on top of A."""

  STR_VAR = 'B'
  BOOL_VAR = True
34 |
35 |
class ExperimentVarsSummaryTest(unittest.TestCase):
  """Round-trips class vars through the text summary and its parser."""

  def test_create_and_parse_cls_vars_summary(self):
    summary = experiment_utils.get_cls_vars_summary(TestExperimentB)

    # Vars are grouped per class along the MRO, blank-line separated.
    lines = summary.splitlines()
    self.assertEqual(lines[0], 'paxml.base_experiment.BaseExperiment:')
    self.assertRegex(lines[1], ' _abc_impl: <_abc._abc_data object at .*>')
    self.assertEqual(lines[2], '')
    self.assertEqual(lines[3], '__main__.TestExperimentA:')
    self.assertEqual(lines[4], ' INT_VAR: 0')
    self.assertEqual(lines[5], " TUPLE_VAR: (0, 'A')")
    self.assertEqual(lines[6], '')
    self.assertEqual(lines[7], '__main__.TestExperimentB:')
    self.assertEqual(lines[8], ' STR_VAR: B')
    self.assertEqual(lines[9], ' BOOL_VAR: True')

    # Parsing the summary recovers the (typed) values, including the override
    # of STR_VAR by TestExperimentB.
    parsed = experiment_vars_summary_parser.parse(summary)
    expected = {
        'INT_VAR': 0,
        'STR_VAR': 'B',
        'TUPLE_VAR': (0, 'A'),
        'BOOL_VAR': True,
    }
    self.assertCountEqual(parsed.keys(), list(expected) + ['_abc_impl'])
    for key, value in expected.items():
      self.assertEqual(parsed[key], value)
64 |
65 |
66 | if __name__ == '__main__':
67 | unittest.main()
68 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/graphviz_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Provides a utility function for rendering configs with normalization."""
17 |
18 | from fiddle import graphviz as fiddle_graphviz
19 | import graphviz
20 | from paxml.tools.fiddle import config_normalization
21 |
22 |
def render(
    config,
    *,
    max_depth: int | None = 4,
    max_str_length: int | None = 100,
    remove_defaults: bool = True,
    convert_dataclasses: bool = True,
    remove_sharding_annotations: bool = False,
    unshare_sharding_config: bool = True,
) -> graphviz.Graph:
  """Normalizes a config and renders it as a graphviz graph.

  Args:
    config: The config to render.
    max_depth: The maximum depth of the rendered graph.
    max_str_length: The maximum length of the rendered strings.
    remove_defaults: Whether to strip values equal to their defaults. Pax
      configs often carry large expanded templates that are unused or purely
      default-valued; removing them makes graphs much more readable.
    convert_dataclasses: Whether to convert dataclass instances to configs.
      Only applied to dataclasses without __post_init__, since __post_init__
      can obscure the original call values.
    remove_sharding_annotations: Whether to drop sharding annotations.
    unshare_sharding_config: Whether to unshare values inside sharding
      configuration (only meaningful when remove_sharding_annotations=False;
      pass False here when annotations are being removed).

  Returns:
    The rendered graphviz graph.
  """
  # Normalize first, then hand the cleaned-up config to Fiddle's renderer.
  normalizer = config_normalization.ConfigNormalizer(
      remove_defaults=remove_defaults,
      convert_dataclasses=convert_dataclasses,
      remove_sharding_annotations=remove_sharding_annotations,
      unshare_sharding_config=unshare_sharding_config,
  )
  normalized = normalizer(config)
  return fiddle_graphviz.render(
      config=normalized,
      max_depth=max_depth,
      max_str_length=max_str_length,
  )
63 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/unshare_sharding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Copies sharding annotations, making them unshared.
17 |
18 | Sharding doesn't matter for sharding annotations, but Fiddle maintains this
19 | information in general for shared objects.
20 | """
21 |
22 | from typing import Any, TypeVar
23 |
24 | from fiddle import daglish
25 | from paxml.tools.fiddle import remove_sharding
26 |
27 |
28 | _T = TypeVar("_T")
29 |
30 |
def _deep_copy(value):
  """Returns a copy of `value` with all sharing between sub-values broken.

  Unlike `copy.deepcopy`, which preserves aliasing (two references to one
  object stay two references to one *copied* object), this traversal emits a
  distinct object for every occurrence. For `input = [shared, shared]`:

    `copy.deepcopy(input)` -> `[shared_2, shared_2]` (still aliased)
    `_deep_copy(input)`    -> `[unshared_3, unshared_4]` (fully unshared)

  where all four values generally compare equal (barring custom equality).

  Args:
    value: Value to copy.

  Returns:
    An unshared copy of `value`.
  """

  def _copy_node(node, state):
    # BasicTraversal (non-memoized) rebuilds every node it visits, so shared
    # sub-objects are duplicated rather than aliased.
    return state.map_children(node)

  return daglish.BasicTraversal.run(_copy_node, value)
54 |
55 |
class CustomMemoizedTraversal(daglish.MemoizedTraversal):
  """Memoized traversal that never caches results for sharding annotations.

  Evicting sharding annotations from the memo table forces every occurrence
  of a shared annotation object to be transformed anew, so the output no
  longer aliases annotation objects.
  """

  def apply(self, value: Any, state: Any) -> Any:
    result = super().apply(value, state)
    if remove_sharding._is_sharding_annotation(value, state.current_path):  # pylint: disable=protected-access
      # Drop the memoized result so the next reference to this (possibly
      # shared) annotation is processed fresh instead of being reused.
      del self.memo[id(value)]
    return result
63 |
64 |
def unshare_sharding(config: _T) -> _T:
  """Returns `config` with every sharding annotation made a distinct object.

  Args:
    config: Base configuration or structure of configuration.

  Returns:
    Config with sharding annotations unshared.
  """

  def _transform(node, state: daglish.State):
    if remove_sharding._is_sharding_annotation(node, state.current_path):  # pylint: disable=protected-access
      # Replace the annotation with a fully unshared deep copy.
      return _deep_copy(node)
    return state.map_children(node)

  # CustomMemoizedTraversal additionally evicts annotations from the memo
  # table so shared annotation objects are copied per occurrence.
  return CustomMemoizedTraversal.run(_transform, config)
81 |
--------------------------------------------------------------------------------
/paxml/docs/tasks.md:
--------------------------------------------------------------------------------
1 | # Tasks
2 |
3 | doc/pax/tasks
4 |
5 | [TOC]
6 |
7 | ## Introduction
8 |
9 | In general, an ML **task** identifies what you are trying to accomplish with
10 | your ML model. Some examples of ML tasks include:
11 |
12 |
13 |
14 |
15 |
16 | * [Regression][regression]
17 | * [Classification][classification]
18 |
19 |
20 |
21 |
22 |
23 | * [Clustering][clustering]
24 | * [Anomaly detection][anomaly]
25 |
26 |
27 |
28 |
29 |
30 | * [Transcription][transcription]
31 | * [Machine translation][translation]
32 |
33 |
34 |
35 |
36 |
37 | In Pax, a **Task** is an object (derived from the [BaseTask][base-task] class)
38 | that is composed of the necessary components to address a given ML task. These
39 | components include:
40 |
41 | * A Model
42 | * Metrics (optional)
43 | * An optimizer
44 |
45 | A **Mixture** is a term for *a collection of Tasks* and enables fine-tuning a
46 | model on multiple Tasks simultaneously.
47 |
> Tip: For a rudimentary introduction to the basic Pax components, check out
49 | > [Pax Elements][pax-elements]. If you want to dive in for a hands-on
50 | > experience, try the [Pax Model and Task Jupyter Notebook][model_ipynb].
51 |
52 | ## Task How-To's
53 |
54 | ### Define a Task
55 |
56 | A Task contains one or more Models and Learner/Optimizers. The simplest Task
57 | subclass is a `SingleTask` which requires the following Hparams:
58 |
59 | ```python
60 | class HParams(base_task.BaseTask.HParams):
61 | """Task parameters.
62 |
63 | Attributes:
64 | name: Name of this task object, must be a valid identifier.
65 | model: The underlying JAX model encapsulating all the layers.
66 | train: HParams to control how this task should be trained.
67 | metrics: A BaseMetrics aggregator class to determine how metrics are
68 | computed.
    loss_aggregator: A LossAggregator aggregator class to determine how the
      losses are aggregated (e.g. single or MultiLoss).
71 | vn: HParams to control variational noise.
72 | ```
73 |
74 |
75 |
76 |
77 |
78 | [anomaly]: internal-link/ml-glossary/#anomaly-detection
79 | [base-task]: https://github.com/google/paxml/tree/main/paxml/base_task.py
80 | [classification]: internal-link/ml-glossary/#classification_model
81 | [clustering]: internal-link/ml-glossary/#clustering
82 | [model_ipynb]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md#pax-model-and-task
83 | [pax-elements]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md#pax-elements
84 | [regression]: internal-link/ml-glossary/#regression-model
85 | [transcription]: https://fireflies.ai/blog/what-is-ai-transcription/
86 | [translation]: https://en.wikipedia.org/wiki/Machine_translation
87 |
--------------------------------------------------------------------------------
/paxml/tasks/vision/params/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Vision modeling model configurations.
18 |
19 | load("//paxml:paxml.bzl", "pytype_library")
20 | # Import for Google-internal testing code.
21 | load("//paxml:build_defs.bzl", "pax_targets")
22 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
23 |
24 | package(default_visibility = JAX_VISIBILITY)
25 |
26 | licenses(["notice"])
27 |
# Experiment configurations for ImageNet ResNet models.
pytype_library(
    name = "params",
    srcs = [
        "imagenet_resnets.py",
    ],
    deps = [
        # Implicit absl.flags dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:learners",
        "//paxml:tasks_lib",
        "//paxml/tasks/vision:input_generator",
        "//praxis:base_input",
        "//praxis:base_layer",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        "//praxis:schedules",
        "//praxis/layers",
    ],
)

# Experiment configuration for a simple MNIST model.
pytype_library(
    name = "mnist",
    srcs = ["mnist.py"],
    deps = [
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:base_task",
        "//paxml:experiment_registry",
        "//paxml:learners",
        "//paxml:tasks_lib",
        "//praxis:base_input",
        "//praxis:base_layer",
        "//praxis:base_model",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        "//praxis:schedules",
        "//praxis/layers:activations",
        "//praxis/layers:convolutions",
        "//praxis/layers:linears",
        "//praxis/layers:poolings",
        # Implicit tensorflow_no_contrib dependency.
        # Implicit tensorflow_datasets dependency.
    ],
)

# Runnable Pax targets for the MNIST experiments, named with a "mnist" prefix.
pax_targets(
    experiments = [
        ":mnist",
    ],
    prefix_name = "mnist",
)

# Runnable Pax targets for the ImageNet ResNet experiments.
pax_targets(
    experiments = [
        ":params",
    ],
)
91 |
92 | # Google-internal testing code.
93 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Language modeling-specific libraries and model configurations
18 |
19 | load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test")
20 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | licenses(["notice"])
25 |
# Input pipeline for language-modeling tasks.
pytype_strict_library(
    name = "input_generator",
    srcs = ["input_generator.py"],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit jax dependency.
        "//praxis:base_input",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        # Implicit tensorflow_no_contrib dependency.
    ],
)

# Shared model/task parameter definitions for LM experiments.
pytype_strict_library(
    name = "model_params",
    srcs = ["model_params.py"],
    deps = [
        # Implicit fiddle dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:tasks_lib",
        "//praxis:asserts",
        "//praxis:base_layer",
        "//praxis:base_model",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:schedules",
        "//praxis/layers",
        "//praxis/layers:activations",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:gpu_fast_attention",
        "//praxis/layers:models",
        "//praxis/layers:transformer_models",
        "//praxis/layers/injection:fp8_nvidia_gpu",
    ],
)

# Test fixtures (e.g. tfrecords) consumed by input_generator_test.
filegroup(
    name = "testdata",
    testonly = 1,
    srcs = glob(["testdata/*"]),
)

pytype_strict_test(
    name = "input_generator_test",
    srcs = ["input_generator_test.py"],
    data = [":testdata"],
    deps = [
        ":input_generator",
        # Implicit absl.testing.absltest.absltest dependency.
        # Implicit absl.testing.parameterized dependency.
        # Implicit numpy dependency.
        "//paxml:test_helper",
        "//praxis:base_hyperparams",
        "//praxis:pax_fiddle",
        "//praxis:test_utils",
        # Implicit tensorflow_no_contrib dependency.
    ],
)
87 |
--------------------------------------------------------------------------------
/paxml/tools/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Tools and utilities for Pax users.
17 |
18 | load("//paxml:paxml.bzl", "pytype_strict_library")
19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
20 |
21 | package(default_visibility = JAX_VISIBILITY)
22 |
23 | licenses(["notice"])
24 |
# Library for dumping experiment input specs (see dump_input_specs.py).
pytype_strict_library(
    name = "dump_input_specs_lib",
    srcs = ["dump_input_specs_lib.py"],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//praxis:base_hyperparams",
        "//praxis:base_input",
        "//praxis:pytypes",
        # Implicit pyglove dependency.
    ],
)

# Standalone tool entry points exported for use as srcs in other packages.
exports_files([
    "dump_hparams.py",
    "dump_input_specs.py",
    "model_analysis.py",
    "validate_config.py",
])

# Library backing the validate_config tool.
pytype_strict_library(
    name = "validate_config_lib",
    srcs = ["validate_config_lib.py"],
    deps = [
        # Implicit absl.app dependency.
        # Implicit absl.flags dependency.
        # Implicit absl.logging dependency.
        # Implicit fiddle.absl_flags dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
    ],
)

# Library backing the dump_hparams tool.
pytype_strict_library(
    name = "dump_hparams_lib",
    srcs = ["dump_hparams_lib.py"],
    deps = [
        # Implicit absl.app dependency.
        # Implicit absl.flags dependency.
        # Implicit absl.logging dependency.
        # Implicit etils dependency.
        # Implicit fiddle.absl_flags dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:experiment_utils",
        "//praxis:base_hyperparams",
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        # Implicit pyglove dependency.
        # Implicit tensorflow_no_contrib dependency.
    ],
)
87 |
--------------------------------------------------------------------------------
/paxml/pip_package/release.Dockerfile:
--------------------------------------------------------------------------------
# Release image: checks out a matching praxis version, installs requirements,
# and runs pip_package/build.sh to produce wheels under $WHEEL_FOLDER.
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
FROM $base_image

LABEL maintainer="Pax team "

# Re-declare args because the args declared before FROM can't be used in any
# instruction after a FROM.
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
# Praxis version: a release branch suffix (rX.Y...), or "release-test" to
# build against praxis HEAD.
ARG praxis_version
# Destination directory for the built wheels.
ARG wheel_folder
# NOTE(review): legacy space-separated `ENV name value` form below — works,
# but `ENV name=value` is the currently recommended syntax.
ENV WHEEL_FOLDER $wheel_folder
ENV PYTHON_VERSION="3"
ENV PYTHON_MINOR_VERSION="10"

# Pick up some TF dependencies
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    pkg-config \
    rename \
    rsync \
    unzip \
    vim \
    && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install python 3.10
RUN apt-get update && apt-get install -y \
    python3 python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/apt/lists/* && \
    python3.10 -m pip install pip --upgrade && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0

# Make python3.10 the default python version
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 0

ARG bazel_version=5.1.1
# This is to install bazel, for development purposes.
ENV BAZEL_VERSION ${bazel_version}
# NOTE(review): the browser User-Agent on these curl calls presumably works
# around download blocking of default curl agents — confirm still needed.
RUN mkdir /bazel && \
    cd /bazel && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
    chmod +x bazel-*.sh && \
    ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    cd / && \
    rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh

COPY . /paxml
RUN mkdir $WHEEL_FOLDER
# "release-test" builds against praxis HEAD; otherwise check out the matching
# release branch r<praxis_version>.
RUN if [ "$praxis_version" = "release-test" ] ; then git clone https://github.com/google/praxis.git; else git clone -b r${praxis_version} https://github.com/google/praxis.git; fi
# Strip " @ git..." pins from requirements so dependencies resolve from PyPI.
RUN if [ "$praxis_version" = "release-test" ] ; then sed -i 's/ @ git.*//g' /praxis/requirements.in; fi
RUN pip3 install -e praxis

RUN cp -r praxis/praxis /paxml/
RUN sed -i 's/ @ git.*//g' paxml/requirements.in
RUN pip3 install -r paxml/requirements.in

# build.sh expects pip_package/ directly under the build root (/paxml).
RUN mv paxml/paxml/pip_package /paxml/
RUN cd /paxml && bash pip_package/build.sh

WORKDIR /

CMD ["/bin/bash"]
70 |
--------------------------------------------------------------------------------
/paxml/host_callback_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unit tests for callback."""
17 |
18 | from absl.testing import absltest
19 | from paxml import host_callback
20 |
21 |
class RepositoryTest(absltest.TestCase):
  """Tests for host_callback.Repository and the repository() registry."""

  def test_namespace(self):
    # A value added under a namespace is retrievable through the same
    # namespace later.
    handle = host_callback.repository('namespace_1').add('regex')
    repo = host_callback.repository('namespace_1')
    self.assertEqual(repo.get(handle), 'regex')

  def test_namespace_same_object(self):
    # The registry memoizes repositories per namespace.
    self.assertIs(
        host_callback.repository('namespace_1'),
        host_callback.repository('namespace_1'),
    )

  def test_namespace_different_object(self):
    self.assertIsNot(
        host_callback.repository('namespace_2'),
        host_callback.repository('namespace_3'),
    )

  def test_add_and_pop(self):
    repo = host_callback.Repository()
    self.assertEqual(repo.size, 0)
    handle = repo.add('regex')
    self.assertEqual(repo.size, 1)
    self.assertTrue(repo.pop(handle))
    self.assertEqual(repo.size, 0)

  def test_pop_unknown_id(self):
    repo = host_callback.Repository()
    self.assertEqual(repo.size, 0)
    # Popping an id that was never added returns False and changes nothing.
    self.assertFalse(repo.pop(0))
    self.assertEqual(repo.size, 0)

  def test_add_and_get(self):
    repo = host_callback.Repository()
    handle = repo.add('regex')
    self.assertEqual(repo.get(handle), 'regex')

  def test_get_unknown_id(self):
    with self.assertRaises(KeyError):
      host_callback.Repository().get(0)

  def test_eviction(self):
    # With max_size=1, each add evicts the previously stored entry.
    repo = host_callback.Repository(max_size=1)
    self.assertEqual(repo.size, 0)
    handles = [repo.add(f'regex_{i}') for i in range(4)]
    self.assertEqual(repo.size, 1)

    # The last regex still remains.
    self.assertEqual(repo.get(handles[-1]), 'regex_3')

    # The others were evicted.
    for handle in handles[:-1]:
      with self.assertRaises(KeyError):
        repo.get(handle)
79 |
80 |
81 | if __name__ == '__main__':
82 | absltest.main()
83 |
--------------------------------------------------------------------------------
/paxml/setup_jax.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Utilities to set up JAX global configs."""
17 |
18 | import dataclasses
19 |
20 | from absl import logging
21 | import jax
22 | from praxis import py_utils
23 | import tensorflow.compat.v2 as tf
24 |
25 |
@dataclasses.dataclass
class JaxDistributedOptions:
  """Options forwarded positionally to `jax.distributed.initialize`."""

  # Address ("host:port") of the coordinator process.
  coordinator_address: str
  # Total number of JAX processes in the job.
  num_processes: int
  # Index of this process, expected in [0, num_processes).
  process_id: int
31 |
32 |
@py_utils.benchmark('[PAX STATUS]: ')
def setup_jax(
    globally_use_hardware_rng: bool,
    jax_backend_target: str | None,
    jax_xla_backend: str | None,
    jax_enable_checks: bool,
    jax_traceback_filtering_option: str = 'auto',
    should_initialize_jax_distributed: bool = False,
    jax_distributed_options: JaxDistributedOptions | None = None,
) -> None:
  """Setups JAX and logs information about this job.

  Mutates global state: TF visible devices, praxis PRNG mode, several JAX
  config options, and (optionally) the JAX distributed runtime.

  Args:
    globally_use_hardware_rng: If True, switch PRNG keys globally to the
      RBG-based implementation via py_utils.
    jax_backend_target: Optional backend target; only logged here, not set.
    jax_xla_backend: Optional XLA backend name; only logged here, not set.
    jax_enable_checks: If True, enable the 'jax_enable_checks' config option.
    jax_traceback_filtering_option: Value for 'jax_traceback_filtering'.
    should_initialize_jax_distributed: Whether to initialize the JAX
      distributed runtime.
    jax_distributed_options: Explicit coordinator/process options; when None,
      `jax.distributed.initialize()` is called with no arguments
      (auto-detection).
  """

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')
  if globally_use_hardware_rng:
    py_utils.set_globally_use_rbg_prng_key()

  # Log tracing and compilation time.
  jax.config.update('jax_log_compiles', True)

  # Allow users to configure JAX traceback filtering.
  # https://github.com/google/jax/blob/main/jax/_src/config.py
  jax.config.update('jax_traceback_filtering', jax_traceback_filtering_option)

  if jax_enable_checks:
    jax.config.update('jax_enable_checks', True)
    logging.info('jax_enable_checks has been enabled.')

  if jax_backend_target:
    logging.info('Using JAX backend target %s', jax_backend_target)
    jax_xla_backend = 'None' if jax_xla_backend is None else jax_xla_backend
    logging.info('Using JAX XLA backend %s', jax_xla_backend)

  if should_initialize_jax_distributed:
    if jax_distributed_options:
      jax.distributed.initialize(
          jax_distributed_options.coordinator_address,
          jax_distributed_options.num_processes,
          jax_distributed_options.process_id,
      )
    else:
      # No explicit options: rely on environment-based auto-detection.
      jax.distributed.initialize()

  # Summarize the process/device topology for debugging multi-host jobs.
  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX devices: %r', jax.devices())
  logging.info('jax.device_count(): %d', jax.device_count())
  logging.info('jax.local_device_count(): %d', jax.local_device_count())
  logging.info('jax.process_count(): %d', jax.process_count())
82 |
--------------------------------------------------------------------------------
/paxml/pip_package/build.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fix: the shebang must be the very first line of the file to take effect;
# it previously appeared after the license header, where it was a plain
# comment with no interpreter-selection effect.

# This script uses custom-op docker, downloads code, builds and tests and then
# builds a pip package.

# See README.md for instructions to use this script.

set -e -x

# Override the following env variables if necessary.
export PYTHON_VERSION="${PYTHON_VERSION:-3}"
export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION}"
export PIP_MANYLINUX2010="${PIP_MANYLINUX2010:-1}"
export DEST="${WHEEL_FOLDER:-/tmp/wheels}"

# Select e.g. "python3" or "python3.10" depending on whether a minor version
# was supplied.
if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then
  PYTHON="python${PYTHON_VERSION}"
else
  PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}"
fi
update-alternatives --install /usr/bin/python3 python3 "/usr/bin/$PYTHON" 1

# Appends a single line to the generated .bazelrc.
function write_to_bazelrc() {
  echo "$1" >> .bazelrc
}

# Appends a `build --action_env NAME="VALUE"` line to the generated .bazelrc.
function write_action_env_to_bazelrc() {
  write_to_bazelrc "build --action_env $1=\"$2\""
}

# Remove .bazelrc if it already exists so we regenerate it from scratch.
[ -e .bazelrc ] && rm .bazelrc

write_to_bazelrc "build -c opt"
write_to_bazelrc 'build --cxxopt="-std=c++14"'
write_to_bazelrc 'build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"'
write_to_bazelrc 'build --auto_output_filter=subpackages'
write_to_bazelrc 'build --copt="-Wall" --copt="-Wno-sign-compare"'
write_to_bazelrc 'build --linkopt="-lrt -lm"'

TF_NEED_CUDA=0
echo 'Using installed tensorflow'
# Ask the installed TF for its compile/link flags so bazel builds against the
# same headers and shared library.
TF_CFLAGS=( $(${PYTHON} -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS="$(${PYTHON} -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')"

# ${TF_CFLAGS:2} strips the leading "-I" from the first include flag.
write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2}
SHARED_LIBRARY_DIR=${TF_LFLAGS:2}
SHARED_LIBRARY_NAME=$(echo $TF_LFLAGS | rev | cut -d":" -f1 | rev)
# If the link flags don't name a specific library (no ":" present), fall back
# to the platform's default TF framework library name.
if ! [[ $TF_LFLAGS =~ .*:.* ]]; then
  if [[ "$(uname)" == "Darwin" ]]; then
    SHARED_LIBRARY_NAME="libtensorflow_framework.dylib"
  else
    SHARED_LIBRARY_NAME="libtensorflow_framework.so"
  fi
fi
write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${SHARED_LIBRARY_DIR}
write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${SHARED_LIBRARY_NAME}
write_action_env_to_bazelrc "TF_NEED_CUDA" ${TF_NEED_CUDA}

bazel clean
bazel build ...
# The vision input generator test is excluded from the presubmit run.
bazel test --test_output=all --test_verbose_timeout_warnings -- paxml/... -paxml/tasks/vision:input_generator_test

./pip_package/build_pip_pkg.sh "$DEST" ${PYTHON_VERSION}
# Record the exact dependency versions alongside the wheels.
pip3 freeze > "${DEST}/dependencies.txt"
--------------------------------------------------------------------------------
/paxml/tools/fiddle/remove_sharding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Pass to remove sharding annotations from a model.
17 |
18 | This can be useful for visualization, and code generation. Currently, the goal
19 | is not perfection; for visualization we just want to remove enough to make
20 | config more readable, and for code generation, sharding annotations are added
21 | back in later (it just helps to factor them into a separate function).
22 | """
23 |
24 | from typing import TypeVar
25 |
26 | import fiddle as fdl
27 | from fiddle import daglish
28 | from praxis import base_layer
29 | from praxis import pax_fiddle
30 |
# Type variable so remove_sharding() returns the same (structure) type it gets.
_T = TypeVar("_T")

# BaseLayer attribute names whose values hold sharding annotations; any value
# reached through one of these fields is stripped.
BASE_LAYER_SHARDING_FIELDS = {
    "weight_split_dims_mapping",
    "activation_split_dims_mapping",
}


# Buildables whose callable is one of these classes are themselves sharding
# annotation objects and are stripped wherever they appear.
SHARDING_TYPES = (
    base_layer.BaseLayer.WeightSharding,
    base_layer.BaseLayer.ActivationSharding,
)
43 |
44 |
def _is_sharding_annotation(value, path: daglish.Path) -> bool:
  """Returns True when `value` or its final path element denotes sharding.

  A value counts as a sharding annotation if it was reached via one of the
  known BaseLayer sharding attributes, or if it is a Buildable whose callable
  is one of the sharding annotation classes.
  """
  tail = path[-1] if path else None
  if isinstance(tail, daglish.Attr) and tail.name in BASE_LAYER_SHARDING_FIELDS:
    return True

  if not isinstance(value, fdl.Buildable):
    return False

  target = fdl.get_callable(value)
  return isinstance(target, type) and issubclass(target, SHARDING_TYPES)
60 |
61 |
62 | class RemoveSentinel:
63 | pass
64 |
65 |
66 | _remove_sentinel = RemoveSentinel()
67 |
68 |
def remove_sharding(config: _T, replace_with_default: bool = False) -> _T:
  """Removes sharding annotations from a config.

  Args:
    config: Base configuration or structure of configuration.
    replace_with_default: Instead of just removing the sharding annotations,
      replace them with the default values.

  Returns:
    Config without sharding annotations.
  """

  def transform(value, state: daglish.State):
    # Bottom-up: rewrite children first, so any sharding sub-config has
    # already been turned into the sentinel by the time `value` is inspected.
    value = state.map_children(value)
    if _is_sharding_annotation(value, state.current_path):
      return _remove_sentinel
    elif isinstance(value, fdl.Buildable):
      # Drop any argument whose (already transformed) value is the sentinel.
      for name, sub_value in fdl.ordered_arguments(value).items():
        if sub_value is _remove_sentinel:
          delattr(value, name)
          if replace_with_default:
            # Re-materialize the field's default from a fresh config of the
            # same callable rather than leaving the attribute unset.
            default_obj = pax_fiddle.Config(fdl.get_callable(value))
            setattr(value, name, getattr(default_obj, name))
    return value

  return daglish.MemoizedTraversal.run(transform, config)
95 |
--------------------------------------------------------------------------------
/paxml/pip_package/prepare_release.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | #!/bin/bash
17 |
18 | # This script prepare a new release by:
19 | # 1) update version number in setup.py and cloudbuild-release.yaml
20 | # 2) add a new section in RELEASE.md with version and corresponding commit
21 |
22 | set -e -x
23 |
# Prints usage information and terminates the script.
# NOTE(review): exits with status 0 even when reached via an invalid flag;
# a nonzero status would be more conventional for the error path — confirm
# callers do not depend on the current behavior before changing.
function print_help_and_exit {
  echo "Usage: prepare_release.sh -v -x -d "
  echo "exp: bash prepare_release.sh -v 0.2.0 -x 0.2.0 -d 20221114"
  exit 0
}
29 |
# Parse flags: -v paxml version, -x praxis version, -d wheel build date
# (YYYYMMDD). -h or any unrecognized flag prints usage and exits.
while getopts "hv:d:x:" opt; do
  case $opt in
    v)
      PAXML_VERSION=${OPTARG}
      ;;
    x)
      PRAXIS_VERSION=${OPTARG}
      ;;
    d)
      BUILD_DATE=${OPTARG}
      ;;
    *)
      print_help_and_exit
      ;;
  esac
done
46 |
RELEASE_NOTE="../RELEASE.md"
RELEASE_NOTE_NEW="release_new.md"

# Both the build date (-d) and paxml version (-v) are mandatory.
if [[ -z "$BUILD_DATE" ]]; then
  echo "Build date is required!"
  exit 1
fi

if [[ -z "$PAXML_VERSION" ]]; then
  echo "paxml version is required!"
  exit 1
fi

echo "Build date: "$BUILD_DATE
echo "PAXML version: "$PAXML_VERSION

# The praxis version (-x) is optional; only rewrite it when provided.
if [[ ! -z "$PRAXIS_VERSION" ]]; then
  sed -i "s/_PRAXIS_VERSION: '[0-9.]*'/_PRAXIS_VERSION: '$PRAXIS_VERSION'/" cloudbuild-release.yaml
fi

# Stamp the new version into setup.py and the release cloudbuild config.
sed -i "s/version='[0-9.]*'/version='$PAXML_VERSION'/" setup.py
sed -i "s/_RELEASE_VERSION: '[0-9.]*'/_RELEASE_VERSION: '$PAXML_VERSION'/" cloudbuild-release.yaml
# Fetch the commit hashes the dated nightly wheels were built from.
gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/paxml_commit.txt ./
gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/praxis_commit.txt ./
# NOTE(review): the original commit-reading lines were corrupted in the file
# ("PAXML_COMMIT=$(> ..."); reconstructed below as reading the downloaded
# commit files — verify against version-control history.
PAXML_COMMIT=$(<paxml_commit.txt)
PRAXIS_COMMIT=$(<praxis_commit.txt)

# Prepend a new release section (with empty headline categories) to RELEASE.md.
echo "# Version: $PAXML_VERSION" > $RELEASE_NOTE_NEW
echo "## Major Features and Improvements" >> $RELEASE_NOTE_NEW
echo "## Breaking changes" >> $RELEASE_NOTE_NEW
echo "## Deprecations" >> $RELEASE_NOTE_NEW
echo "## Note" >> $RELEASE_NOTE_NEW
echo "* Version: $PAXML_VERSION" >> $RELEASE_NOTE_NEW
echo "* Build Date: $BUILD_DATE" >> $RELEASE_NOTE_NEW
echo "* Paxml commit: $PAXML_COMMIT" >> $RELEASE_NOTE_NEW
echo "* Praxis version: $PRAXIS_VERSION" >> $RELEASE_NOTE_NEW
echo "* Praxis commit: $PRAXIS_COMMIT" >> $RELEASE_NOTE_NEW
RELEASE_NOTE_TMP="RELEASE.tmp.md"
cat $RELEASE_NOTE_NEW $RELEASE_NOTE >> $RELEASE_NOTE_TMP
rm $RELEASE_NOTE_NEW
rm $RELEASE_NOTE
mv $RELEASE_NOTE_TMP $RELEASE_NOTE
92 |
--------------------------------------------------------------------------------
/paxml/docs/tutorials/inputs_in_Pax-eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "4Rh0P4V34wZm"
7 | },
8 | "source": [
9 | "# Pax Workshop\n",
10 | "## Inputs in Pax - eval\n",
11 | "\n",
12 | "This colab demonstrates how inputs in Pax work.\n"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {
19 | "id": "mNJyBaezlI7O"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "from praxis import base_input\n",
24 | "from praxis import base_hyperparams\n",
25 | "from paxml import seqio_input\n",
26 | "import numpy as np"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "2PJLV0fSCsyM"
33 | },
34 | "source": [
35 | "Let's look at some eval data."
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "id": "x2-q5_vhBByz"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "import t5.data.mixtures\n",
47 | "p_eval = seqio_input.SeqIOInput.HParams(\n",
48 | " mixture_name='super_glue_copa_v102',\n",
49 | " split_name='validation',\n",
50 | " task_feature_lengths={\n",
51 | " 'inputs': 256,\n",
52 | " 'targets': 32,\n",
53 | " },\n",
54 | " feature_converter=seqio_input.LanguageModelFeatures(pack=False),\n",
55 | " batch_size=4,\n",
56 | " infeed_host_index=1,\n",
57 | " num_infeed_hosts=4,\n",
58 | " is_training=False,\n",
59 | " reset_for_eval=True)\n",
60 | "inp_eval = base_hyperparams.instantiate(p_eval)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "q0gj3RvcBoNU"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# The dataset is finite, and raises exception after being exhausted\n",
72 | "inp_eval.reset()\n",
73 | "for i in range(10):\n",
74 | " try:\n",
75 | " batch = inp_eval.get_next()\n",
76 | " except StopIteration:\n",
77 | " print(i, 'batches')\n",
78 | " break"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {
84 | "id": "he9Np05oCHD_"
85 | },
86 | "source": [
87 | "Try tweaking `infeed_host_index` above and verify that each host raises after the same number of batches. What happens if we change `num_infeed_hosts`? How does that change the number of batches it takes to exhaust the data?\n",
88 | "\n",
89 | "Also verify that `batch.eval_sample_weights` field is used to tell which examples are added as paddings."
90 | ]
91 | }
92 | ],
93 | "metadata": {
94 | "kernelspec": {
95 | "display_name": "Python 3 (ipykernel)",
96 | "language": "python",
97 | "name": "python3"
98 | },
99 | "language_info": {
100 | "codemirror_mode": {
101 | "name": "ipython",
102 | "version": 3
103 | },
104 | "file_extension": ".py",
105 | "mimetype": "text/x-python",
106 | "name": "python",
107 | "nbconvert_exporter": "python",
108 | "pygments_lexer": "ipython3",
109 | "version": "3.8.10"
110 | }
111 | },
112 | "nbformat": 4,
113 | "nbformat_minor": 1
114 | }
115 |
--------------------------------------------------------------------------------
/paxml/ghostnorm/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# Description:
#   Library of layers that implement the ghost norm protocol.
load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

# Core ghost-norm protocol definitions shared by the layer libraries below.
pytype_strict_library(
    name = "base",
    srcs = ["base.py"],
    deps = [
        # Implicit flax.core dependency.
        # Implicit jax dependency.
        "//praxis:pytypes",
    ],
)

# Ghost-norm implementations of linear layers.
pytype_strict_library(
    name = "linears",
    srcs = ["linears.py"],
    deps = [
        ":base",
        # Implicit jax dependency.
        "//praxis:pytypes",
        "//praxis/layers",
        "//praxis/layers:base_ops",
    ],
)

# Ghost-norm implementations of embedding layers.
pytype_strict_library(
    name = "embedding",
    srcs = ["embedding.py"],
    deps = [
        ":base",
        ":linears",
        # Implicit jax dependency.
        # Implicit jax.experimental.sparse dependency.
        "//praxis:base_layer",
        "//praxis:pytypes",
        "//praxis/layers",
        "//praxis/layers:base_ops",
    ],
)

# Wrapper that applies the ghost-norm layer variants across transformer
# components (attention, embedding softmax, linears, normalizations).
pytype_strict_library(
    name = "generic_wrapper",
    srcs = ["generic_wrapper.py"],
    deps = [
        ":base",
        ":embedding",
        ":linears",
        # Implicit flax.core dependency.
        # Implicit jax dependency.
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:pytypes",
        "//praxis/layers:attentions",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:linears",
        "//praxis/layers:normalizations",
        "//praxis/layers:transformers",
    ],
)

# Unit tests covering all ghost-norm layer libraries in this package.
pytype_strict_test(
    name = "layers_test",
    srcs = ["layers_test.py"],
    deps = [
        ":base",
        ":embedding",
        ":generic_wrapper",
        ":linears",
        # Implicit absl.testing.absltest.absltest dependency.
        # Implicit absl.testing.parameterized dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        # Implicit optax dependency.
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:linears",
        "//praxis/layers:transformer_models",
        # Implicit tensorflow_no_contrib dependency.
    ],
)
102 |
--------------------------------------------------------------------------------
/paxml/contrib/gpu/scripts_gpu/lora_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import ABC
17 | import fiddle as fdl
18 | from paxml import tasks_lib
19 | from praxis import pax_fiddle
20 | from praxis.contrib.gpu.scripts_gpu.lora_layers import (
21 | LoraAttentionProjection,
22 | LoraCombinedQKVProjection,
23 | LoraLinear,
24 | )
25 | from praxis.layers import transformers
26 |
27 |
class LoRAMixin(ABC):
  """Experiment mixin that injects LoRA adapter layers into a transformer LM.

  Subclasses set USE_LORA / LORA_RANK / LORA_TARGET_LAYERS and call
  configure_lora() on their task config. When enabled, the feed-forward
  and/or attention projection templates are swapped for LoRA variants and
  training is restricted to variables whose names contain "lora".
  """

  # Master switch; configure_lora() is a no-op when False.
  USE_LORA = False
  # Rank of the low-rank adapter matrices.
  LORA_RANK = 8
  # Which layers receive adapters: "all", "attention", or "mlp".
  LORA_TARGET_LAYERS = "all"

  def _validate(self):
    """Raises ValueError when LORA_TARGET_LAYERS has an unsupported value."""
    if self.LORA_TARGET_LAYERS not in ["all", "attention", "mlp"]:
      # Fixed: the message previously referred to a nonexistent
      # "LAYERS_TO_INCLUDE_FOR_LORA" setting instead of LORA_TARGET_LAYERS.
      raise ValueError(
          "LORA_TARGET_LAYERS should be one of all, attention or mlp."
      )

  def configure_lora(
      self, task_p: pax_fiddle.Config[tasks_lib.SingleTask]
  ) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Injects LoRA layer templates into `task_p`.

    Args:
      task_p: Single-task config whose model holds a stacked-transformer LM.

    Returns:
      The same config, mutated in place when USE_LORA is True; unchanged
      otherwise.

    Raises:
      ValueError: If LORA_TARGET_LAYERS is not one of the supported values.
    """
    if not self.USE_LORA:
      return task_p

    self._validate()
    train_p = task_p.train

    # NOTE(review): this only takes effect when the subclass already defines
    # CHECKPOINT_IGNORE_RULES, and then replaces it wholesale — confirm that
    # is the intended behavior.
    if hasattr(self, "CHECKPOINT_IGNORE_RULES"):
      self.CHECKPOINT_IGNORE_RULES = [r"^.*lora.*$"]

    # Only LoRA variables receive gradients; base weights stay frozen.
    train_p.learner.bprop_variable_inclusion = [r"^.*lora.*$"]
    stacked_p = task_p.model.lm_tpl.stacked_transformer_tpl
    # Unwrap the repeated-block wrapper to reach the per-layer template.
    if issubclass(
        fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated
    ):
      stacked_p = stacked_p.block
    stacked_p = stacked_p.transformer_layer_params_tpl

    if self.LORA_TARGET_LAYERS in ["all", "mlp"]:
      # Replace the feed-forward linear template, preserving its settings.
      ff_templ = stacked_p.tr_fflayer_tpl.fflayer_tpl
      original_linear_p = ff_templ.linear_tpl
      ff_templ.linear_tpl = pax_fiddle.Config(
          LoraLinear,
          rank=self.LORA_RANK,
          name="lora_linear",
      )
      ff_templ.linear_tpl.copy_fields_from(original_linear_p)

    if self.LORA_TARGET_LAYERS in ["all", "attention"]:
      # Combined QKV projection exists only on some attention templates.
      if hasattr(stacked_p.tr_atten_tpl, "combined_qkv_proj_tpl"):
        original_combined_qkv_p = stacked_p.tr_atten_tpl.combined_qkv_proj_tpl
        stacked_p.tr_atten_tpl.combined_qkv_proj_tpl = pax_fiddle.Config(
            LoraCombinedQKVProjection,
            name="lora_qkv_projection",
            rank=self.LORA_RANK,
        )
        stacked_p.tr_atten_tpl.combined_qkv_proj_tpl.copy_fields_from(
            original_combined_qkv_p
        )

      # Output projection always gets an adapter in attention mode.
      original_proj_p = stacked_p.tr_atten_tpl.proj_tpl
      stacked_p.tr_atten_tpl.proj_tpl = pax_fiddle.Config(
          LoraAttentionProjection,
          name="lora_attention_projection",
          rank=self.LORA_RANK,
      )
      stacked_p.tr_atten_tpl.proj_tpl.copy_fields_from(original_proj_p)

    return task_p
90 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/make_parameterized_experiment.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Makes a parameterized experiment from a legacy BaseExperiment subclass.
17 |
18 | TODO(b/292000357): Add unit tests. This is currently decently well integration
19 | tested through codegen_test.py
20 | """
21 |
22 | from typing import Type
23 |
24 | from paxml import base_experiment
25 | from paxml import parameterized_experiment
26 | from paxml.tools.fiddle import config_normalization
27 | from praxis import pax_fiddle
28 |
29 |
def from_legacy(
    experiment_cls: Type[base_experiment.BaseExperiment],
    *,
    normalizer: config_normalization.ConfigNormalizer
    | None = config_normalization.default_normalizer(),
    has_train_dataset: bool = True,
    has_input_specs_provider: bool = True,
) -> pax_fiddle.Config[parameterized_experiment.ParameterizedExperiment]:
  """Converts a legacy BaseExperiment subclass to a ParameterizedExperiment.

  Args:
    experiment_cls: Subclass of BaseExperiment.
    normalizer: Normalizer applied to the resulting config; pass None to skip
      normalization entirely.
    has_train_dataset: Whether it is safe to call train_datasets(). Set to
      False for inference-only experiments whose train_datasets() would raise.
    has_input_specs_provider: Whether it is safe to call
      get_input_specs_provider_params(); disabling can be useful in tests.

  Returns:
    A ParameterizedExperiment config mirroring the legacy experiment.
  """
  experiment: base_experiment.BaseExperiment = experiment_cls()

  # Query the legacy experiment in the same order as before: task, datasets,
  # then decode datasets.
  task_config = experiment.task()
  eval_datasets = [
      ds for ds in experiment.datasets() if not ds.is_training
  ]
  decode_datasets = experiment.decode_datasets()
  if not isinstance(decode_datasets, list):
    decode_datasets = list(decode_datasets)

  result = pax_fiddle.Config(
      parameterized_experiment.ParameterizedExperiment,
      task=task_config,
      eval_datasets=eval_datasets,
  )
  if has_train_dataset:
    result.train_datasets = experiment.train_datasets()
  if has_input_specs_provider:
    result.input_specs_provider = experiment.get_input_specs_provider_params()
  if decode_datasets:
    result.decode_datasets = decode_datasets

  # Normalize (if requested) and return.
  return normalizer(result) if normalizer else result
82 |
--------------------------------------------------------------------------------
/paxml/base_inference_runner_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for base_inference_runner."""
17 |
18 | from __future__ import annotations
19 |
20 | from typing import Any, Sequence
21 |
22 | from absl.testing import absltest
23 | import jax
24 | import numpy as np
25 | from paxml import base_inference_runner
26 | from paxml import train_states
27 | from praxis import base_hyperparams
28 | from praxis import base_layer
29 | from praxis import base_model
30 | from praxis import pax_fiddle
31 | from praxis import py_utils
32 | from praxis import pytypes
33 | from praxis import test_utils
34 | import tensorflow.compat.v2 as tf
35 | import tensorflow_datasets as tfds
36 |
# Short aliases for helpers and types used throughout this test module.
instantiate = base_hyperparams.instantiate
NestedMap = py_utils.NestedMap
NestedWeightHParams = base_layer.NestedWeightHParams
PRNGKey = pytypes.PRNGKey
TrainState = train_states.TrainState
42 |
43 |
class DummyInference(base_inference_runner.BaseInferenceRunner):
  # Canned runner for testing: `output` is returned verbatim by infer(), and
  # `output_schema` satisfies the abstract property via a plain class field.
  output: Any = None
  output_schema: Any = None

  def infer(self, train_state: TrainState, prng_key: PRNGKey,
            var_weight_hparams: NestedWeightHParams,
            input_batch: NestedMap) -> NestedMap:
    """Ignores all arguments and returns the pre-configured `output`."""
    return self.output
52 |
53 |
class BaseInferenceRunnerTest(test_utils.TestCase):
  """Checks that inference outputs round-trip through serialize_outputs."""

  def test_infer(self):
    """Serializes a canned output batch and verifies deserialized equality."""
    fake_output = NestedMap(
        tensor=np.arange(64, dtype=np.float32).reshape(8, 8),
        nested=NestedMap(
            text=np.array(
                [f'{i}'.encode('utf-8') for i in range(8)], dtype=object
            )
        ),
    )
    fake_schema = NestedMap(
        tensor=tfds.features.Tensor(shape=(8,), dtype=tf.float32),
        nested=NestedMap(text=tfds.features.Text()),
    )

    runner_config = pax_fiddle.Config(
        DummyInference, output=fake_output, output_schema=fake_schema
    )
    runner = runner_config.Instantiate(model=None)

    # infer() ignores its arguments, so dummy values suffice for all four.
    serialized = runner.serialize_outputs(runner.infer(None, None, None, None))

    expected = py_utils.tree_unstack(fake_output, 0)
    self.assertEqual(len(serialized), len(expected))

    features_dict = tfds.features.FeaturesDict(fake_schema)
    for raw, want in zip(serialized, expected):
      decoded = features_dict.deserialize_example(raw)
      decoded_np = jax.tree.map(lambda t: t.numpy(), decoded)

      for got_leaf, want_leaf in zip(
          jax.tree_util.tree_leaves(decoded_np),
          jax.tree_util.tree_leaves(want)):
        self.assertArraysEqual(got_leaf, want_leaf)
89 |
90 |
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
93 |
--------------------------------------------------------------------------------
/paxml/profiling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Expose functionalities for profiling code."""
17 |
18 | from absl import logging
19 |
20 |
21 | class Profiler:
22 | """Dummy class to capture code profiles.
23 |
24 | Note: The current implementation is a no-op.
25 | """
26 |
27 | def __init__(
28 | self,
29 | num_steps: float = 2.0,
30 | min_duration_sec: float = 1.0,
31 | default_duration_sec: float = 5.0,
32 | tag: str | None = None,
33 | max_num_hosts: int | None = None,
34 | capture_host_profile: bool = False,
35 | ) -> None:
36 | """Constructor.
37 |
38 | Args:
39 | num_steps: The number of steps to capture based on the step duration
40 | estimate that is set by calling update_step_moving_mean() successfully.
41 | min_duration_sec: The minimum duration of the profiler capture in seconds.
42 | Set to this value when the estimate step duration times num_steps is
43 | smaller than this value.
44 | default_duration_sec: The default duration of the profiler capture in
45 | seconds. Used when no step duration were sampled by calling
46 | update_step_moving_mean().
47 | tag: An optional tag to be added to the profiler trace.
48 | max_num_hosts: If max_num_hosts is unspecified, all hosts of devices will
49 | be chosen. Otherwise, at most max_num_hosts will be chosen. This option
50 | only works with pathways.
51 | capture_host_profile: Capture host CPU profile as well.
52 | """
53 | self._capture_num_steps = num_steps
54 | self._capture_min_duration_sec = min_duration_sec
55 | self._capture_default_duration_sec = default_duration_sec
56 | self._tag = tag
57 | self._max_num_hosts = max_num_hosts
58 | self._capture_host_profile = capture_host_profile
59 | self._step_duration_sec = 0.
60 | self._step_count = 0
61 |
62 | def capture_async(self) -> None:
63 | """Captures a trace asynchronously.
64 |
65 | The duration of the trace corresponds to step_duration_estimate_sec.
66 | """
67 | logging.info('Dummy profiler currently does not capture any trace.')
68 |
69 | def update_step_moving_mean(self, duration_sec: float):
70 | """Updates the step duration moving average with a step duration estimate.
71 |
72 | Args:
73 | duration_sec: The duration of the step to add in seconds.
74 | """
75 | self._step_duration_sec += duration_sec
76 | self._step_count += 1
77 |
78 | @property
79 | def step_duration_estimate_sec(self) -> float:
80 | """Estimates of the step duration in seconds.
81 |
82 | If update_step_moving_mean() has never been called, returns the default
83 | duration instead.
84 | """
85 | if not self._step_count:
86 | return self._capture_default_duration_sec
87 | return self._step_duration_sec / self._step_count
88 |
--------------------------------------------------------------------------------
/paxml/base_inference_runner.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base API for inference runners."""
17 |
18 | from __future__ import annotations
19 |
20 | import abc
21 | import dataclasses
22 | from typing import Any
23 |
24 | from paxml import train_states
25 | from praxis import base_hyperparams
26 | from praxis import base_layer
27 | from praxis import base_model
28 | from praxis import lazy_loader
29 | from praxis import py_utils
30 | from praxis import pytypes
31 |
32 | # TFDS is slow to import, so we do it lazily.
33 | tfds = lazy_loader.LazyLoader('tfds', globals(), 'tensorflow_datasets')
34 |
35 | NestedMap = py_utils.NestedMap
36 | NestedWeightHParams = base_layer.NestedWeightHParams
37 | PRNGKey = pytypes.PRNGKey
38 | TrainState = train_states.TrainState
39 |
40 |
class BaseInferenceRunner(base_hyperparams.FiddleBaseParameterizable, abc.ABC):
  """Abstract container pairing an infer function with its output schema.

  Subclasses define (1) how inference outputs are produced from a model and an
  input batch, and (2) the tfds feature schema used to serialize those
  outputs.

  TODO(b/238220793): Only Jax-native types are currently written, since all
  computation happens in a jit-ed context; non-native types such as strings
  may be supported eventually.
  """
  _model: Any = dataclasses.field(init=False, repr=False)
  model: base_model.BaseModel = None

  def __post_init__(self):
    # Mirror the configured model onto the non-init, non-repr field.
    self._model = self.model

  @abc.abstractmethod
  def infer(self, train_state: TrainState, prng_key: PRNGKey,
            var_weight_hparams: NestedWeightHParams,
            input_batch: NestedMap) -> NestedMap:
    """Produces an output batch from a model and input. Should be pmap-able."""

  @property
  @abc.abstractmethod
  def output_schema(self) -> NestedMap:
    """Schema describing the structure of serialized outputs.

    Must be a nested map of `tfds.features.FeatureConnector` types. See
    https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeatureConnector
    for more information. Example:

    ```
    return NestedMap(
        bucket_keys=tfds.features.Scalar(dtype=tf.int32),
        nested=NestedMap(
            field=tfds.features.Tensor(shape=(1000,), dtype=tf.int32)
        ),
        logprobs=tfds.features.Tensor(shape=(1, 32,), dtype=tf.float32),
    )
    ```
    """

  def serialize_outputs(self, outputs: NestedMap) -> list[bytes]:
    """Splits `outputs` along the batch axis and serializes each example."""
    features_dict = tfds.features.FeaturesDict(self.output_schema)
    # Axis 0 is the input batch dimension.
    return [
        features_dict.serialize_example(example)
        for example in py_utils.tree_unstack(outputs, 0)
    ]
95 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/params/optimal_scaling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Decoder-only language model configurations with Chinchilla-like scaling."""
17 |
18 | from paxml import experiment_registry
19 | from paxml import tasks_lib
20 | from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd
21 | from praxis import layers
22 | from praxis import pax_fiddle
23 |
24 |
class OptimalScaling(LmCloudSpmd):
  """Decoder-only language model configurations with Chinchilla-like scaling."""

  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM

  # Subclasses must override these.
  PERCORE_BATCH_SIZE = None
  NUM_LAYERS = None

  def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Derives MODEL_DIMS/HIDDEN_DIMS from NUM_LAYERS, then builds the task."""
    # pylint: disable=invalid-name
    assert self.NUM_LAYERS
    model_dims = 128 * self.NUM_LAYERS
    self.MODEL_DIMS = model_dims
    self.HIDDEN_DIMS = 4 * model_dims
    # pylint: enable=invalid-name
    return super().task()
41 |
42 |
@experiment_registry.register
class OptimalScaling2x2x1(OptimalScaling):
  """28-layer config for a 2x2x1 topology (4 chips; mesh product 1*4*1=4)."""

  NUM_LAYERS = 28
  PERCORE_BATCH_SIZE = 16
  ICI_MESH_SHAPE = [1, 4, 1]
48 |
49 |
@experiment_registry.register
class OptimalScaling2x2x2(OptimalScaling):
  """32-layer config for a 2x2x2 topology (8 chips; mesh product 8)."""

  NUM_LAYERS = 32
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 8, 1]
55 |
56 |
@experiment_registry.register
class OptimalScaling2x2x4(OptimalScaling):
  """36-layer config for a 2x2x4 topology (16 chips; mesh product 16)."""

  NUM_LAYERS = 36
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 16, 1]
62 |
63 |
@experiment_registry.register
class OptimalScaling2x4x4(OptimalScaling):
  """40-layer config for a 2x4x4 topology (32 chips; mesh product 32)."""

  NUM_LAYERS = 40
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 32, 1]
69 |
70 |
@experiment_registry.register
class OptimalScaling4x4x4(OptimalScaling):
  """45-layer config for a 4x4x4 topology (64 chips; mesh product 64)."""

  NUM_LAYERS = 45
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 64, 1]
76 |
77 |
@experiment_registry.register
class OptimalScaling4x4x8(OptimalScaling):
  """50-layer config for a 4x4x8 topology (128 chips; mesh product 128)."""

  NUM_LAYERS = 50
  PERCORE_BATCH_SIZE = 4
  ICI_MESH_SHAPE = [1, 128, 1]
83 |
84 |
@experiment_registry.register
class OptimalScaling4x8x8(OptimalScaling):
  """56-layer config for a 4x8x8 topology (256 chips; mesh product 64*4=256)."""

  NUM_LAYERS = 56
  PERCORE_BATCH_SIZE = 2
  ICI_MESH_SHAPE = [1, 64, 4]
90 |
91 |
@experiment_registry.register
class OptimalScaling4x8x16(OptimalScaling):
  """64-layer config for a 4x8x16 topology (512 chips; mesh product 128*4)."""

  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 2
  ICI_MESH_SHAPE = [1, 128, 4]
97 |
98 |
@experiment_registry.register
class OptimalScaling4x16x16(OptimalScaling):
  """64-layer config for a 4x16x16 topology (1024 chips; mesh product 256*4)."""

  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 256, 4]
104 |
105 |
@experiment_registry.register
class OptimalScaling4x16x32(OptimalScaling):
  """64-layer config for a 4x16x32 topology (2048 chips; mesh product 512*4)."""

  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 512, 4]
111 |
112 |
@experiment_registry.register
class OptimalScaling4x24x32(OptimalScaling):
  """64-layer config for a 4x24x32 topology (3072 chips; mesh product 768*4)."""

  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 768, 4]
118 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/codegen_highlevel_parameterization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Codegen pass to trace high-level settings and make those useable."""
17 |
18 | import dataclasses
19 | from typing import Any
20 |
21 | from fiddle import daglish
22 | from fiddle._src.codegen.auto_config import code_ir
23 | from fiddle.codegen.auto_config import experimental_top_level_api
24 | from paxml.tools.fiddle import codegen_pax_code_ir
25 |
26 |
@dataclasses.dataclass(frozen=True)
class HighlevelParameterization(experimental_top_level_api.CodegenPass):
  """Parameterizes fixtures by a highlevel settings object.

  Each fixture function gains a leading `self` parameter; calls to sibling
  fixtures become `self.<fixture>()`, and any traced highlevel value (an object
  carrying `__highlevel_name__`) is replaced by `self.<name>`, with the access
  recorded in `task.highlevel_accesses`.
  """

  # If True, highlevel attribute names are emitted lowercased.
  lowercasing: bool = False

  def __call__(self, task: Any, **pass_kwargs: Any) -> Any:
    assert isinstance(task, codegen_pax_code_ir.PaxCodegenTask)
    all_fns = task.top_level_call.all_fixture_functions()

    def process_fn(fn: code_ir.FixtureFunction):
      self_name = code_ir.Name("self", is_generated=False)
      fn.parameters.insert(0, code_ir.Parameter(self_name))

      def highlevel_to_expression(obj):
        """Records access to a highlevel value and returns `self.<name>`."""
        attribute_name = obj.__highlevel_name__
        if self.lowercasing:
          attribute_name = attribute_name.lower()
        task.highlevel_accesses[attribute_name] = obj
        return code_ir.AttributeExpression(
            base=code_ir.VariableReference(self_name),
            attribute=attribute_name,
        )

      def add_self_to_calls(value, state: daglish.State):
        value = state.map_children(value)

        # Convert calls to sub-fixtures like model_fixture() to
        # self.model_fixture().
        if isinstance(value, code_ir.SymbolOrFixtureCall):
          if isinstance(value.symbol_expression, code_ir.FixtureReference):
            value.symbol_expression = code_ir.AttributeExpression(
                base=code_ir.VariableReference(self_name),
                attribute=value.symbol_expression.name.value,
            )

        # Convert any instances of highlevel variables to relevant expressions.
        elif hasattr(value, "__highlevel_name__"):
          return highlevel_to_expression(value)

        # Process dict keys too (normally ignored by daglish).
        if isinstance(value, dict):
          converted = {}
          for key, sub_value in value.items():
            if hasattr(key, "__highlevel_name__"):
              key = highlevel_to_expression(key)
            converted[key] = sub_value
          return converted

        return value

      fn.replace_with(daglish.MemoizedTraversal.run(add_self_to_calls, fn))

    for fn in all_fns:
      process_fn(fn)

    return task
88 |
--------------------------------------------------------------------------------
/paxml/pip_package/Dockerfile:
--------------------------------------------------------------------------------
# Build image that produces Pax (paxml + praxis) pip wheels, pins their
# requirements with pip-tools, and runs a subset of the praxis test suite.
# Expects a `wheel_folder` build arg naming the wheel/requirements output dir.
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
FROM $base_image

LABEL maintainer="Pax team "

# Re-declare args because the args declared before FROM can't be used in any
# instruction after a FROM.
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
ARG wheel_folder
ENV WHEEL_FOLDER $wheel_folder
ENV PYTHON_VERSION="3"
ENV PYTHON_MINOR_VERSION="10"

# Pick up some TF dependencies
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \
        build-essential \
        curl \
        git \
        pkg-config \
        rename \
        rsync \
        unzip \
        vim \
        && \
        apt-get clean && \
        rm -rf /var/lib/apt/lists/*

# Install python 3.10
RUN apt-get update && apt-get install -y \
    python3 python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/apt/lists/* && \
    python3.10 -m pip install pip --upgrade && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0

# Install pip package
RUN pip3 install pip-tools

ARG bazel_version=5.1.1
# This is to install bazel, for development purposes.
ENV BAZEL_VERSION ${bazel_version}
# The User-Agent header works around GitHub blocking bare curl downloads.
RUN mkdir /bazel && \
    cd /bazel && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
    chmod +x bazel-*.sh && \
    ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    cd / && \
    rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh

COPY . /paxml

# Record the exact commits the wheels were built from, next to the wheels.
RUN mkdir -p $WHEEL_FOLDER && cd paxml && \
    git rev-parse HEAD > $WHEEL_FOLDER/paxml_commit.txt

RUN git clone https://github.com/google/praxis.git && \
    cd praxis && git rev-parse HEAD > $WHEEL_FOLDER/praxis_commit.txt

RUN cp -r praxis/praxis /paxml/

RUN cd /praxis && \
    pip-compile --quiet requirements.in \
    --output-file $WHEEL_FOLDER/praxis_requirements.txt

# Comment out paxml's own praxis requirement so pip-compile resolves against
# the freshly cloned praxis requirements instead of a released wheel.
RUN cd /paxml && \
    sed -i 's/praxis/#praxis/' requirements.in && \
    pip-compile --quiet requirements.in \
    /praxis/requirements.in \
    --output-file $WHEEL_FOLDER/paxml_requirements.txt

# NOTE(review): the `paxml/paxml/...` paths below assume the build context
# root contains a `paxml/` subdirectory (i.e. context is the repo's parent),
# while COPY above placed the context at /paxml — confirm the context layout.
RUN pip3 install --no-deps -r paxml/paxml/pip_package/requirements.txt

RUN mv paxml/paxml/pip_package /paxml/
RUN cd /paxml && bash pip_package/build.sh
#TODO:enable -praxis/layers:normalizations_test once the new Lingvo pip package is released
RUN cd praxis && \
    bazel test \
    --test_output=all \
    --test_verbose_timeout_warnings \
    -- \
    praxis/... \
    -praxis/layers:attentions_test \
    -praxis/layers:convolutions_test \
    -praxis/layers:ctc_objectives_test \
    -praxis/layers:embedding_softmax_test \
    -praxis/layers:models_test \
    -praxis/layers:ngrammer_test \
    -praxis/layers:normalizations_test \
    -praxis/layers:transformer_models_test \
    -praxis/layers:transformers_test

RUN cd praxis && bash praxis/pip_package/build_pip_pkg.sh

WORKDIR /

CMD ["/bin/bash"]
99 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/codegen_tracer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tracer classes for the Pax-specific codegen."""
17 |
18 | # Note: Fiddle and Pax devs are in collaboration; please generally do not import
19 | # private libraries from Fiddle.
20 |
21 | import dataclasses
22 | import enum
23 | from typing import Type, TypeVar
24 |
25 | from fiddle import daglish
26 | from paxml import base_experiment
27 |
28 |
@dataclasses.dataclass(frozen=True)
class BoolTracer:
  """Special-case tracer for bool values, since we can't inherit from bool."""

  name: str
  value: bool

  @property
  def __highlevel_name__(self):
    """Highlevel setting name this tracer stands in for."""
    return self.name

  def __bool__(self):
    return self.value

  def __eq__(self, other):
    # Compare transparently against plain bools; defer everything else to the
    # default (identity-based) comparison.
    if not isinstance(other, bool):
      return super().__eq__(other)
    return self.value == other
47 |
48 |
def make_tracer(
    name, value, allowed_types=(bool, int, float, str, list, tuple)
):
  """Wraps a value in a tracer object, that has a name."""
  value_type = type(value)
  if not issubclass(value_type, allowed_types):
    raise ValueError(
        f"Type {value_type} is not allowed. If it seems to work with "
        "the subclassed tracers, please add it to allowed_types "
        "and write a unit test."
    )

  # bool cannot be subclassed, so it gets a dedicated wrapper.
  if value_type is bool:
    return BoolTracer(name, value)
  traced_cls = type(
      f"Wrapped{value_type.__name__}_{name}",
      (value_type,),
      {"__highlevel_name__": name},
  )
  return traced_cls(value)
67 |
68 |
class TracerMixin:
  """Mixin that will wrap experiments' property accessors with traced versions.

  This means that when there are high-level settings as attributes on a
  BaseExperiment instance, a traced value will be returned. This can then be
  intercepted in code generation, resulting in partially abstracted code.
  """

  # Names of attributes that have been accessed (and hence traced) so far.
  __trace_names__: set[str]

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # object.__setattr__ bypasses any __setattr__ override on the experiment.
    object.__setattr__(self, "__trace_names__", set())

  def __getattribute__(self, name):
    result = super().__getattribute__(name)
    if name == "__trace_names__":
      # Avoid infinite recursion / tracing our own bookkeeping attribute.
      return result
    # Enums are returned untouched; checked before the primitive branch since
    # e.g. IntEnum values are also instances of int.
    if isinstance(result, enum.Enum):
      return result
    if isinstance(result, (bool, int, float, str, list, tuple)):
      self.__trace_names__.add(name)
      return make_tracer(name, result)
    return result
94 |
95 |
def make_subclass_mixin(
    experiment_cls: Type[base_experiment.BaseExperiment],
):
  """Creates a dynamic subclass of an experiment that adds TracerMixin."""
  if not issubclass(experiment_cls, base_experiment.BaseExperiment):
    raise TypeError("Please pass a subclass of BaseExperiment.")
  bases = (experiment_cls, TracerMixin)
  return type(experiment_cls.__name__ + "Traced", bases, {})
104 |
105 |
106 | _T = TypeVar("_T")
107 |
108 |
def remove_tracers(root: _T) -> _T:
  """Returns a copy of `root` with tracer wrappers replaced by plain values."""

  def untrace(value, state: daglish.State):
    if isinstance(value, BoolTracer):
      # BoolTracer stores its bool rather than subclassing it.
      value = bool(value)
    elif hasattr(value, "__highlevel_name__"):
      # Wrapped tracers subclass the original type; rebuild from the base.
      value = type(value).__bases__[0](value)
    return state.map_children(value)

  return daglish.MemoizedTraversal.run(untrace, root)
118 |
--------------------------------------------------------------------------------
/paxml/tools/dump_input_specs_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Input specification retrieval from either provider or input pipeline."""
17 |
18 | import pprint
19 |
20 | from absl import logging
21 | import jax
22 | from paxml import base_experiment
23 | from praxis import base_hyperparams
24 | from praxis import base_input
25 | from praxis import pytypes
26 | import pyglove as pg
27 |
28 | NestedShapeDtypeStruct = pytypes.NestedShapeDtypeStruct
29 | instantiate = base_hyperparams.instantiate
30 |
31 |
def extract_input_specs(
    experiment_config: base_experiment.BaseExperiment,
) -> tuple[NestedShapeDtypeStruct | None, NestedShapeDtypeStruct | None]:
  """Extracts the input specs for a given experiment config.

  Args:
    experiment_config: The experiment whose input specs should be retrieved.

  Returns:
    A `(specs_from_provider, specs_from_pipeline)` tuple. Either element may
    be None: the first when the experiment uses a DatasetInputSpecsProvider,
    the second when no training input pipeline could be found.
  """
  logging.info('Starting extraction of input specs info.')

  # Input specs from input_specs_provider.
  input_specs_provider_p = experiment_config.get_input_specs_provider_params()
  input_specs_from_provider = None
  # NOTE(review): DatasetInputSpecsProvider is skipped here — presumably
  # because it derives specs from the dataset itself, which the pipeline
  # branch below already covers; confirm.
  if not isinstance(input_specs_provider_p,
                    base_input.DatasetInputSpecsProvider):
    logging.info('Extracting input specs info from provider.')
    specs_provider = instantiate(input_specs_provider_p)
    input_specs_from_provider = specs_provider.get_input_specs()

  # NOTE(daiyip): putting `training_dataset()` and `instantiate(train_input_p)`
  # under an AutoML context allows dynamic evaluation of hyperparameters that is
  # to be swept. The first values of all `pg.oneof` will be used.
  with pg.hyper.DynamicEvaluationContext(require_hyper_name=True).collect():
    # Input specs from experiment config
    logging.info('Extracting input specs info from input pipeline.')
    try:
      # Clone it since we may mutate a few attributes below.
      train_input_p = experiment_config.train_datasets()[0].clone()
    except ValueError:
      # Presumably raised when the experiment configures no training dataset.
      logging.info('Could not find a training input pipeline for %s',
                   experiment_config)
      train_input_p = None

    if train_input_p is None:
      return input_specs_from_provider, None

    # Attempt at reducing loading time when using Lingvo input.
    if isinstance(train_input_p, base_input.LingvoInputAdaptor):
      train_input_p.input.num_batcher_threads = 1
      train_input_p.input.file_parallelism = 1
      train_input_p.input.file_buffer_size = 32

    logging.info('Instantiating input pipeline...')
    input_pipeline = instantiate(train_input_p)
    logging.info('Retrieving specs from input pipeline...')
    # Draw one padded batch just to read shapes/dtypes; keep only the
    # ShapeDtypeStruct skeleton, not the data.
    input_specs_from_input_pipeline = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_pipeline.get_next_padded())

  return input_specs_from_provider, input_specs_from_input_pipeline
78 |
79 |
def specs_to_string(
    experiment_name: str,
    specs: tuple[NestedShapeDtypeStruct | None, NestedShapeDtypeStruct | None],
) -> str:
  """Converts input specs into a readable string."""
  printer = pprint.PrettyPrinter(indent=2)
  provider_specs, pipeline_specs = specs
  parts = [
      experiment_name,
      'From InputSpecsProvider:',
      printer.pformat(provider_specs),
      'From training input pipeline:',
      printer.pformat(pipeline_specs),
      '\n\n',
  ]
  result = '\n'.join(parts)
  logging.info(result)
  return result
97 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/params/c4_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for GPT-3 models defined in c4.py."""
17 | import os
18 |
19 | from absl.testing import absltest
20 | import fiddle as fdl
21 | import jax
22 | from paxml.tasks.lm.params import c4
23 | from praxis import layers
24 | from praxis import optimizers
25 | from praxis import schedules
26 | from praxis import test_utils
27 |
28 | prev_xla_flags = None
29 |
30 |
31 | def setUpModule():
32 | global prev_xla_flags
33 | prev_xla_flags = os.getenv("XLA_FLAGS")
34 | flags_str = prev_xla_flags or ""
35 | # Don't override user-specified device count, or other XLA flags.
36 | if "xla_force_host_platform_device_count" not in flags_str:
37 | os.environ["XLA_FLAGS"] = (
38 | flags_str + " --xla_force_host_platform_device_count=768"
39 | )
40 | # Clear any cached backends so new CPU backend will pick up the env var.
41 | jax.extend.backend.get_backend.cache_clear()
42 |
43 |
44 | def tearDownModule():
45 | if prev_xla_flags is None:
46 | del os.environ["XLA_FLAGS"]
47 | else:
48 | os.environ["XLA_FLAGS"] = prev_xla_flags
49 | jax.extend.backend.get_backend.cache_clear()
50 |
51 |
class C4Test(test_utils.TestCase):
  """Validates the GPT-3 MLPerf reference configuration defined in c4.py."""

  def test_gpt3_mlperf_bs1p5k_config(self):
    """Checks architecture, batching, optimizer and LR-schedule settings."""
    config = c4.C4SpmdPipelineGpt3AdamMLPerfHPBS1p5k768Replicas()
    task_p = config.task()

    # Model architecture
    lm_tpl = task_p.model.lm_tpl
    self.assertEqual(config.MAX_SEQ_LEN, 2048)
    self.assertEqual(config.NUM_LAYERS, 96)
    self.assertEqual(config.NUM_HEADS, 96)
    self.assertEqual(lm_tpl.model_dims, 12288)
    self.assertEqual(config.HIDDEN_DIMS, 12288 * 4)
    # GPT vocabularies may be padded above the nominal 50257 entries.
    self.assertGreaterEqual(lm_tpl.vocab_size, 50257)
    self.assertEqual(
        lm_tpl.position_emb_tpl.cls,
        layers.embedding_softmax.TrainablePositionalEmbedding,
    )

    # device_count() is 768 here (forced via XLA_FLAGS in setUpModule); the
    # + 1e-6 guards the int() truncation against float rounding.
    global_batch_size = int(
        config.PERCORE_BATCH_SIZE * jax.device_count() + 1e-6
    )
    self.assertEqual(global_batch_size, 1536)
    self.assertEqual(
        task_p.train.eval_interval_steps * global_batch_size, 24 * 1024
    )

    # Early stopping fn
    self.assertEqual(task_p.early_stopping_fn.cls, c4.EarlyStoppingFn)
    self.assertAlmostEqual(task_p.early_stopping_fn.target_log_pplx, 2.69)

    # optimizer and HPs
    optimizer_p = task_p.train.learner.optimizer
    self.assertEqual(fdl.get_callable(optimizer_p), optimizers.Adam)
    self.assertAlmostEqual(optimizer_p.weight_decay, 0.1)
    self.assertAlmostEqual(optimizer_p.epsilon, 1e-8)
    self.assertAlmostEqual(optimizer_p.beta1, 0.9)
    self.assertAlmostEqual(optimizer_p.beta2, 0.95)
    self.assertAlmostEqual(optimizer_p.clip_gradient_norm_to_value, 1.0)

    # LR schedule (step counts are expressed in examples: steps * batch size).
    lr_schedule = optimizer_p.lr_schedule
    self.assertEqual(lr_schedule.cls, schedules.LinearRampupCosineDecay)
    self.assertEqual(lr_schedule.warmup_steps * global_batch_size, 265 * 1536)
    self.assertEqual(lr_schedule.decay_start, lr_schedule.warmup_steps + 1)
    self.assertEqual(lr_schedule.decay_end * global_batch_size, 108600 * 1536)
    self.assertAlmostEqual(optimizer_p.learning_rate, 2e-5)
    self.assertAlmostEqual(lr_schedule.min_ratio, 0.1)
    self.assertAlmostEqual(lr_schedule.max, 1.0)
101 |
102 |
103 | if __name__ == "__main__":
104 | absltest.main()
105 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/codegen_external_init_checkpoint_fns_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for codegen_external_init_checkpoint_fns."""
17 |
18 | from absl.testing import absltest
19 | from fiddle._src.codegen.auto_config import init_task
20 | from fiddle._src.codegen.auto_config import ir_printer
21 | from fiddle.experimental import visualize
22 | from paxml.tools.fiddle import codegen_external_init_checkpoint_fns
23 | from paxml.tools.fiddle import codegen_pax_code_ir
24 | from paxml.tools.fiddle import test_fixtures
25 |
26 |
27 | class InitCheckpointRulesFromOtherTaskTest(absltest.TestCase):
28 |
29 | def test_creates_calls(self):
30 | config = test_fixtures.SampleExperimentWithInitFromCheckpointRules().task()
31 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True)
32 | task = init_task.init_task(config)
33 | task = codegen_pax_code_ir.PaxCodegenTask(
34 | original_config=task.original_config,
35 | top_level_call=task.top_level_call,
36 | )
37 | codegen_pass = (
38 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask()
39 | )
40 | task = codegen_pass(
41 | task,
42 | init_checkpoint_experiments={
43 | "/path/to/my/checkpoint": (
44 | test_fixtures.SampleExperimentWithInputSpecsProvider
45 | )
46 | },
47 | )
48 |
49 | debug_str = ir_printer.format_task(task)
50 | self.assertIn(
51 | "task_p=call:.task(*[[]], **{})>",
54 | debug_str,
55 | )
56 | self.assertIn(
57 | "input_specs_provider_p=call:.get_input_specs_provider_params(*[[]], **{})>",
60 | debug_str,
61 | )
62 |
63 | def test_errors_unused_rule(self):
64 | config = test_fixtures.SampleExperiment().task()
65 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True)
66 | task = init_task.init_task(config)
67 | task = codegen_pax_code_ir.PaxCodegenTask(
68 | original_config=task.original_config,
69 | top_level_call=task.top_level_call,
70 | )
71 | codegen_pass = (
72 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask()
73 | )
74 | with self.assertRaisesRegex(
75 | ValueError, r"Didn't encounter.*path/to/my/checkpoint"
76 | ):
77 | codegen_pass(
78 | task,
79 | init_checkpoint_experiments={
80 | "/path/to/my/checkpoint": (
81 | test_fixtures.SampleExperimentWithInputSpecsProvider
82 | )
83 | },
84 | )
85 |
86 | def test_errors_unmatched_rule(self):
87 | config = test_fixtures.SampleExperimentWithInitFromCheckpointRules().task()
88 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True)
89 | task = init_task.init_task(config)
90 | task = codegen_pax_code_ir.PaxCodegenTask(
91 | original_config=task.original_config,
92 | top_level_call=task.top_level_call,
93 | )
94 | codegen_pass = (
95 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask()
96 | )
97 | with self.assertRaisesRegex(
98 | ValueError, r"No task for checkpoint /path/to/my/checkpoint"
99 | ):
100 | codegen_pass(
101 | task,
102 | init_checkpoint_experiments={},
103 | )
104 |
105 |
106 | if __name__ == "__main__":
107 | absltest.main()
108 |
--------------------------------------------------------------------------------
/paxml/tools/fiddle/remove_sharding_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for remove_sharding."""
17 |
18 | import dataclasses
19 | import typing
20 |
21 | from absl.testing import absltest
22 | import fiddle as fdl
23 | from fiddle import daglish
24 | from paxml.tools.fiddle import remove_sharding
25 | from praxis import base_layer
26 | from praxis import pax_fiddle
27 |
28 |
class TestLayer(base_layer.BaseLayer):
  # Minimal layer fixture; `x` is a non-sharding attribute used to verify that
  # ordinary fields survive sharding removal.
  x: int = 12
31 |
32 |
def fake_config():
  """Builds a TestLayer config that carries a weight-sharding annotation."""
  cfg = pax_fiddle.Config(TestLayer, x=14)
  cfg.weight_split_dims_mapping.wt = ("foo", "bar")
  return cfg
37 |
38 |
def _is_sharding_config(typ):
  """True iff `typ` is `pax_fiddle.Config[S]` for a known sharding type S."""
  origin = typing.get_origin(typ)
  if isinstance(origin, type) and issubclass(origin, pax_fiddle.Config):
    args = typing.get_args(typ)
    return len(args) == 1 and args[0] in remove_sharding.SHARDING_TYPES
  return False
47 |
48 |
class RemoveShardingTest(absltest.TestCase):
  """Unit tests for the remove_sharding codegen utility."""

  def test_base_layer_sharding_fields(self):
    """BASE_LAYER_SHARDING_FIELDS agrees with fields found via type hints."""
    detected = {
        name
        for name, typ in typing.get_type_hints(base_layer.BaseLayer).items()
        if _is_sharding_config(typ)
    }
    self.assertEqual(detected, remove_sharding.BASE_LAYER_SHARDING_FIELDS)

  def test_is_sharding_annotation(self):
    """A WeightSharding config is recognized by type alone (empty path)."""
    self.assertTrue(
        remove_sharding._is_sharding_annotation(
            value=pax_fiddle.Config(base_layer.BaseLayer.WeightSharding),
            path=(),
        )
    )

  def test_is_sharding_annotation_works_with_function(self):
    """A Config of a plain function is not mistaken for a sharding config."""
    self.assertFalse(
        remove_sharding._is_sharding_annotation(
            value=pax_fiddle.Config(fake_config),
            path=(),
        )
    )

  def test_is_sharding_annotation_by_path(self):
    # This is just to catch if the user forgets to inherit from WeightSharding /
    # ActivationSharding.
    class Foo(base_layer.BaseLayer):

      @dataclasses.dataclass
      class MyActivationSplitDimsMapping:
        bar_sharding: list[str] = dataclasses.field(default_factory=list)

      activation_split_dims_mapping: MyActivationSplitDimsMapping = (
          dataclasses.field(default_factory=MyActivationSplitDimsMapping)
      )

    config = pax_fiddle.Config(
        Foo,
        activation_split_dims_mapping=pax_fiddle.Config(
            Foo.MyActivationSplitDimsMapping, bar_sharding=["baz", "qux"]
        ),
    )
    self.assertTrue(
        remove_sharding._is_sharding_annotation(
            value=config.activation_split_dims_mapping,
            path=(daglish.Attr("activation_split_dims_mapping"),),
        )
    )

  def test_remove_sharding(self):
    """remove_sharding drops sharding args but leaves other args intact."""
    config = fake_config()
    without_sharding = remove_sharding.remove_sharding(config=config)
    self.assertIn("weight_split_dims_mapping", fdl.ordered_arguments(config))
    self.assertNotIn(
        "weight_split_dims_mapping", fdl.ordered_arguments(without_sharding)
    )

    # Ensure other attributes were not deleted.
    self.assertEqual(config.x, 14)
    self.assertEqual(without_sharding.x, 14)

  def test_replace_with_defaults(self):
    config = fake_config()
    without_sharding = remove_sharding.remove_sharding(
        config=config, replace_with_default=True
    )
    # NOTE(review): no assertion here — this only verifies that the sharding
    # sub-config still exists and is mutable after replace_with_default.
    # Consider asserting on the pre-mutation default value of `wt`.
    without_sharding.weight_split_dims_mapping.wt = ()
119 |
120 |
121 | if __name__ == "__main__":
122 | absltest.main()
123 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/params/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Language modeling model configurations.
18 |
19 | load("//paxml:paxml.bzl", "pytype_library")
20 | load("//paxml:paxml.bzl", "py_strict_test")
21 | load("//paxml:build_defs.bzl", "pax_targets")
22 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
23 |
24 | package(default_visibility = JAX_VISIBILITY)
25 |
26 | licenses(["notice"])
27 |
# Language-model experiment parameter definitions (BERT, C4/GPT-3, multislice,
# cloud LM, optimal-scaling configs). NOTE(review): tags=["keep_dep"]
# presumably prevents dependency pruning so that experiment-registration
# import side effects are kept — confirm against paxml.bzl.
pytype_library(
    name = "params",
    srcs = [
        "bert.py",
        "c4.py",
        "c4_multislice.py",
        "lm_cloud.py",
        "optimal_scaling.py",
    ],
    tags = ["keep_dep"],
    deps = [
        ":lm_cloud",
        # Implicit absl.logging dependency.
        # Implicit etils dependency.
        # Implicit fiddle dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:seqio_input",
        "//paxml:tasks_lib",
        "//paxml:trainer_lib",
        "//paxml/tasks/lm:input_generator",
        "//paxml/tasks/lm:model_params",
        "//praxis:base_hyperparams",
        "//praxis:base_input",
        "//praxis:base_layer",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:schedules",
        "//praxis/layers",
        "//praxis/layers:transformers",
        # Implicit seqio dependency.
        # Implicit t5.data dependency.
        # Implicit t5.data.preprocessors dependency.
    ],
)
64 |
# NVIDIA GPU experiment configurations; builds on :params and :lm_cloud and
# pulls in GPU-specific praxis layers (fast attention, glam, grok).
pytype_library(
    name = "nvidia",
    srcs = [
        "nvidia.py",
    ],
    tags = ["keep_dep"],
    deps = [
        ":lm_cloud",
        ":params",
        # Implicit fiddle dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        "//paxml:experiment_registry",
        "//paxml:tasks_lib",
        "//paxml/tasks/lm:model_params",
        "//praxis:base_layer",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:schedules",
        "//praxis/layers",
        "//praxis/layers:activations",
        "//praxis/layers:glam",
        "//praxis/layers:gpu_fast_attention",
        "//praxis/layers:grok",
        "//praxis/layers:transformers",
    ],
)
92 |
# Standalone library for the lm_cloud experiment configurations.
# Note: lm_cloud.py is also compiled into :params, which depends on this
# target.
pytype_library(
    name = "lm_cloud",
    srcs = [
        "lm_cloud.py",
    ],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:tasks_lib",
        "//paxml/tasks/lm:input_generator",
        "//paxml/tasks/lm:model_params",
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis/layers",
    ],
)
111 |
# Unit test for the C4 experiment configurations in :params.
py_strict_test(
    name = "c4_test",
    timeout = "long",
    srcs = ["c4_test.py"],
    deps = [
        ":params",
        # Implicit absl.testing.absltest.absltest dependency.
        # Implicit fiddle dependency.
        # Implicit jax dependency.
        "//praxis:optimizers",
        "//praxis:schedules",
        "//praxis:test_utils",
        "//praxis/layers",
    ],
)
127 |
# Pax binary targets for the experiments in :params (see
# //paxml:build_defs.bzl for what pax_targets expands to).
pax_targets(
    experiments = [
        ":params",
    ],
)
133 |
# Pax binary targets for the NVIDIA GPU experiments in :nvidia; generated
# target names are prefixed with "gpu".
pax_targets(
    experiments = [
        ":nvidia",
    ],
    prefix_name = "gpu",
)
140 |
--------------------------------------------------------------------------------
/paxml/xla_passthrough.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Helpers for handling inputs unsupported by XLA."""
17 |
18 | from absl import logging
19 | import jax
20 | import numpy as np
21 |
22 |
def split_out_xla_unsupported_batch(batch, partitioning_spec=None):
  """Splits out values not supported by XLA (such as strings) from the batch.

  This is used to pass the unsupported values through a different channel. See
  also `merge_back_xla_unsupported_batch`.

  Note: when a nested sub-dict contains unsupported values, the caller's
  `batch` is mutated in place for that key (the sub-dict is replaced by its
  supported remainder).

  Args:
    batch: The input (possibly nested) dictionary.
    partitioning_spec: The dictionary with partitioning information or None.

  Returns:
    A tuple of the following elements.
    batch: The original batch dictionary with unsupported elements removed.
    unsupported_batch: A dictionary with only the unsupported elements.
    partitioning_spec: The original partitioning_spec with unsupported elements
      removed.

  Raises:
    NotImplementedError: If unsupported values were found while running with
      more than one process; forwarding them is not supported in multi-host
      setups yet.
  """
  unsupported_batch = {}
  new_partitioning_spec = {}

  for k, v in batch.items():
    if hasattr(v, 'items'):
      # Recurse into nested dictionaries and split them the same way.
      nested_batch, nested_unsupported_batch, nested_partitioning_spec = (
          split_out_xla_unsupported_batch(
              v,
              partitioning_spec=partitioning_spec.get(k)
              if partitioning_spec
              else None,
          )
      )
      if nested_unsupported_batch:
        batch[k] = nested_batch
        unsupported_batch[k] = nested_unsupported_batch
        if (
            partitioning_spec
            and k in partitioning_spec
            and nested_partitioning_spec
        ):
          new_partitioning_spec[k] = nested_partitioning_spec
      # Always skip the dtype checks below for nested dicts: a dict has no
      # `.dtype`. (Previously this `continue` only ran when unsupported values
      # were found, so a fully-supported sub-dict raised an AttributeError.)
      continue

    # Keep anything that is not a string/object array; only those are
    # unsupported by XLA.
    if not np.issubdtype(v.dtype, np.str_) and not np.issubdtype(
        v.dtype, np.object_
    ):
      continue

    unsupported_batch[k] = v

  # If no unsupported keys were detected, return out the original batch object
  # without modifying it.
  if not unsupported_batch:
    return batch, {}, partitioning_spec

  # Multi-host handling of unsupported values is not implemented yet.
  if jax.process_count() > 1:
    # TODO(b/279795947): Support xla passthrough for multihost eval.
    raise NotImplementedError(
        (
            'Unsupported inputs (with keys %s) were detected, but running with'
            ' more than one host. Forwarding these keys is currently not'
            ' supported (but may be supported in the future).'
        )
        % unsupported_batch.keys(),
    )

  batch = {k: v for k, v in batch.items() if k not in unsupported_batch}
  if partitioning_spec is not None:
    new_partitioning_spec.update(
        {
            k: v
            for k, v in partitioning_spec.items()
            if k not in unsupported_batch
        }
    )
  else:
    new_partitioning_spec = None
  return batch, unsupported_batch, new_partitioning_spec
101 |
102 |
def merge_back_xla_unsupported_batch(out, xla_unsupported_batch):
  """Adds back unsupported parts of the batch into out.

  This is done in case process_decode_out or other parts of the code want to
  make use of the unsupported parts.

  Args:
    out: The output dictionary without unsupported parts.
    xla_unsupported_batch: A dictionary with only the unsupported elements, if
      any.
  """
  if not xla_unsupported_batch:
    # Nothing was split out; leave `out` untouched.
    return
  out.update(xla_unsupported_batch)
116 |
--------------------------------------------------------------------------------
/paxml/tasks/lm/input_generator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test lm input generator."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 | from paxml import test_helper
22 | from paxml.tasks.lm import input_generator
23 | from praxis import base_hyperparams
24 | from praxis import pax_fiddle
25 | from praxis import test_utils
26 | import tensorflow.compat.v2 as tf
27 |
28 | instantiate = base_hyperparams.instantiate
29 |
30 |
class InputTest(test_utils.TestCase):
  """Tests for the TFRecordBertInput generator."""

  def _get_first_id_and_lengths(self, batch):
    # The first label id together with the per-example sequence length acts
    # as a fingerprint identifying which examples landed in each shard.
    first_ids = batch.labels[:, 1]
    seq_lengths = np.sum(batch.segment_ids, axis=1, dtype=np.int32)
    return first_ids, seq_lengths

  def test_full(self):
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # The test data file holds exactly 10 examples.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 10

    batch = instantiate(p).get_next()
    ids, lengths = self._get_first_id_and_lengths(batch)
    self.assertArraysEqual(
        ids,
        np.array([2003, 1996, 1996, 2049, 3748, 1007, 4862, 1996, 2004, 2002],
                 dtype=np.int32))
    self.assertArraysEqual(
        lengths,
        np.array([35, 239, 55, 56, 511, 511, 161, 43, 416, 511],
                 dtype=np.int32))

  def test_remask(self):
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # The test data file holds exactly 10 examples.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 10
    p.is_training = True
    p.remask = True
    p.mlm_augmenter.Set(mask_token_id=103, vocab_size=8000)

    batch = instantiate(p).get_next()
    ids, lengths = self._get_first_id_and_lengths(batch)
    # Remasked contents vary, so only the shapes are checked.
    self.assertEqual(ids.shape, (10,))
    self.assertEqual(lengths.shape, (10,))

  @parameterized.parameters(True, False)
  def test_sharded(self, provide_data_size):
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # The test data file holds exactly 10 examples.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 4
    p.eval_data_size = 10 if provide_data_size else 0
    sharded_inputs = [
        instantiate(p.clone().set(infeed_host_index=i, num_infeed_hosts=4))
        for i in range(4)
    ]

    # Same fingerprints as in test_full(), zero-padded to fill the last batch
    # and reshaped to [num_infeed_hosts, batch_size].
    padding = np.zeros([6], dtype=np.int32)
    expected_ids = np.reshape(
        np.concatenate([
            np.array([2003, 1996, 1996, 2049, 3748, 1007, 4862, 1996, 2004,
                      2002], dtype=np.int32),
            padding,
        ], axis=0), [4, -1])
    expected_lengths = np.reshape(
        np.concatenate([
            np.array([35, 239, 55, 56, 511, 511, 161, 43, 416, 511],
                     dtype=np.int32),
            padding,
        ], axis=0), [4, -1])

    for i in [1, 3, 2, 0]:
      # Each shard produces exactly one batch before going out of range.
      ids, lengths = self._get_first_id_and_lengths(
          sharded_inputs[i].get_next())
      self.assertArraysEqual(ids, expected_ids[i])
      self.assertArraysEqual(lengths, expected_lengths[i])

      with self.assertRaisesRegex(tf.errors.OutOfRangeError,
                                  'End of sequence'):
        sharded_inputs[i].get_next()
105 |
106 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  absltest.main()
109 |
--------------------------------------------------------------------------------
/paxml/train_states.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TrainState class for encapsulating model weights and optimizer states."""
17 |
18 | from __future__ import annotations
19 |
20 | import dataclasses
21 | from typing import Any, Generic, TypeVar
22 |
23 | from flax import struct as flax_struct
24 | import jax
25 | import jaxtyping as jt
26 | import optax
27 | from praxis import base_layer
28 | from praxis import py_utils
29 | from praxis import pytypes
30 | from praxis import trees
31 |
32 | JTensor = py_utils.JTensor
33 | JTensorProvenance = tuple[str, int | None]
34 | JTensorOrPartitionSpec = pytypes.JTensorOrPartitionSpec
35 | Nested = pytypes.Nested
36 | NestedJTensor = base_layer.NestedJTensor
37 | NestedJTensorOrPartitionSpec = pytypes.NestedJTensorOrPartitionSpec
38 | NestedMap = py_utils.NestedMap
39 |
40 |
41 | _ArrayOrPSpec = TypeVar('_ArrayOrPSpec', jax.Array, jax.sharding.PartitionSpec)
42 | """Either a pspec (when tracing) or a Jax tensor."""
43 | ExtraStateType = NestedJTensorOrPartitionSpec | None
44 |
45 |
46 | # A helper class for managing various train states. This struct may contain the
47 | # actual Jax tensors, or simply PartitionSpecs for the corresponding tensor.
48 | # If the latter, this struct is used for specifying the PartitionSpecs for
49 | # input/output to/from a pjit-ed function.
class TrainState(flax_struct.PyTreeNode, Generic[_ArrayOrPSpec]):
  """Simple train state.

  Depending on the `_ArrayOrPSpec` type parameter, the fields hold either
  actual Jax tensors or `PartitionSpec`s for the corresponding tensors (the
  latter when specifying input/output specs for a pjit-ed function).
  """

  # Current training step (a scalar tensor, or its PartitionSpec).
  step: _ArrayOrPSpec
  # Nested model variables (or their PartitionSpecs).
  mdl_vars: jt.PyTree[_ArrayOrPSpec]
  # Optimizer states, one nested pytree per optimizer.
  opt_states: list[jt.PyTree[_ArrayOrPSpec]]
  # Optional extra state; empty tuple when unused.
  extra_state: ExtraStateType = ()

  def new_state(
      self,
      mdl_vars: NestedJTensor,
      opt_states: list[optax.OptState],
      extra_state: ExtraStateType = (),
  ) -> TrainState:
    """Returns a new TrainState with updated mdl_vars and opt_states.

    The step is incremented by one; all inputs are copied via `trees.copy` so
    the returned state does not alias the caller's structures.

    Args:
      mdl_vars: Updated model variables.
      opt_states: Updated optimizer states.
      extra_state: Updated extra state, if any.

    Returns:
      A new TrainState with `step` advanced by 1.
    """
    return TrainState(
        step=self.step + 1,
        mdl_vars=trees.copy(mdl_vars),
        opt_states=trees.copy(opt_states),
        extra_state=trees.copy(extra_state),
    )

  def to_eval_state(self) -> TrainState:
    """Returns a new TrainState with opt_states removed, for eval purpose."""
    return TrainState(
        step=self.step, mdl_vars=self.mdl_vars, opt_states=[], extra_state=()
    )
77 |
78 |
@dataclasses.dataclass
class TensorProvenance:
  """Records where a tensor's value came from.

  Either the default `'random_init'` (freshly initialized) or a checkpoint
  path plus an optional step (None meaning the latest step).
  """

  checkpoint_path: str = 'random_init'
  checkpoint_step: int | None = None

  def __repr__(self) -> str:
    """Renders as `"(random_init)"` or `"(path:step)"` / `"(path:latest)"`."""
    if self.checkpoint_path == 'random_init':
      return f'"({self.checkpoint_path})"'
    step = 'latest' if self.checkpoint_step is None else self.checkpoint_step
    return f'"({self.checkpoint_path}:{step})"'
93 |
94 |
@dataclasses.dataclass
class TrainStateProvenance:
  """Provenance for the TrainState pytree struct (not jax-transformable)."""

  # Provenance of the step counter.
  step: TensorProvenance
  # Per-leaf provenance mirroring the structure of TrainState.mdl_vars.
  mdl_vars: Nested[TensorProvenance]
  # Per-leaf provenance mirroring the structure of TrainState.opt_states.
  opt_states: Nested[TensorProvenance]
  # Per-leaf provenance mirroring the structure of TrainState.extra_state.
  extra_state: Nested[TensorProvenance]

  def replace(self, **changes: Any) -> TrainStateProvenance:
    """Returns a copy of this instance with the given fields replaced."""
    return dataclasses.replace(self, **changes)
106 |
107 |
def build_train_state_provenance(
    train_state: TrainState,
    checkpoint_path: str | None = None,
    step: int | None = None,
) -> TrainStateProvenance:
  """Builds provenance records for every tensor in `train_state`.

  Args:
    train_state: The TrainState holding actual tensors (not PartitionSpecs).
    checkpoint_path: Path of the checkpoint the state came from, or None for
      a randomly initialized state.
    step: Checkpoint step, or None for the latest.

  Returns:
    A TrainStateProvenance mirroring the structure of `train_state`, with the
    same provenance record attached to every leaf.
  """
  assert not isinstance(
      train_state.step, jax.sharding.PartitionSpec
  ), 'Tensor provenance is only for tensors'

  if checkpoint_path:
    provenance = TensorProvenance(
        checkpoint_path=checkpoint_path, checkpoint_step=step
    )
  else:
    provenance = TensorProvenance()

  def _tag(_):
    # Every leaf is tagged with the same shared provenance record.
    return provenance

  return TrainStateProvenance(
      step=provenance,
      mdl_vars=jax.tree.map(_tag, train_state.mdl_vars),
      opt_states=jax.tree.map(_tag, train_state.opt_states),
      extra_state=jax.tree.map(_tag, train_state.extra_state),
  )
128 |
--------------------------------------------------------------------------------
/paxml/xla_passthrough_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for xla_passthrough."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 | from paxml import xla_passthrough
21 |
22 |
class InputUtilsTest(absltest.TestCase):
  """Tests for splitting out and merging back XLA-unsupported batch keys."""

  def test_split_out_xla_unsupported_batch_noop(self):
    batch = {'a': np.array([1, 2, 3, 4]), 'b': np.array([5, 6, 7, 8])}
    spec = {'a': 'fake_spec', 'b': 'fake_spec'}
    out_batch, out_unsupported, out_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=spec
        )
    )
    # In no-op cases, the exact same object should be returned.
    self.assertIs(batch, out_batch)
    self.assertEmpty(out_unsupported)
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertEqual(out_batch, batch)
    # Partitioning spec should not be modified.
    self.assertIsNotNone(out_spec)
    self.assertCountEqual(out_spec.keys(), {'a', 'b'})

  def test_split_out_xla_unsupported_batch_singly_nested(self):
    batch = {
        'a': np.array([1, 2, 3, 4]),
        'b': np.array([5, 6, 7, 8]),
        'c': np.array(['a', 'b', 'c', 'd']),
    }
    spec = {'a': 'fake_spec', 'b': 'fake_spec', 'c': 'fake_spec'}
    out_batch, out_unsupported, out_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=spec
        )
    )
    self.assertCountEqual(out_batch.keys(), {'a', 'b'})
    self.assertCountEqual(out_unsupported.keys(), {'c'})
    # The unsupported values themselves pass through intact.
    self.assertEqual(list(out_unsupported['c']), ['a', 'b', 'c', 'd'])
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertCountEqual(out_batch.keys(), {'a', 'b', 'c'})
    # The input partitioning spec must remain untouched.
    self.assertCountEqual(spec.keys(), {'a', 'b', 'c'})
    # The returned spec drops the unsupported key.
    self.assertIsNotNone(out_spec)
    self.assertCountEqual(out_spec.keys(), {'a', 'b'})

  def test_split_out_xla_unsupported_batch_multi_nested(self):
    batch = {
        'a': np.array([1, 2, 3, 4]),
        'b': np.array([5, 6, 7, 8]),
        'c': {
            'd': np.array(['a', 'b', 'c', 'd']),
            'e': np.array([1, 2, 3, 4]),
        },
    }
    spec = {
        'a': 'fake_spec',
        'b': 'fake_spec',
        'c': {'d': 'fake_spec', 'e': 'fake_spec'},
    }
    out_batch, out_unsupported, out_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=spec
        )
    )
    self.assertCountEqual(out_batch.keys(), {'a', 'b'})
    self.assertCountEqual(out_unsupported.keys(), {'c'})
    self.assertCountEqual(out_unsupported['c'].keys(), {'d'})
    # The unsupported values themselves pass through intact.
    self.assertEqual(list(out_unsupported['c']['d']), ['a', 'b', 'c', 'd'])
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertCountEqual(out_batch.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(out_batch['c'].keys(), {'d'})
    # The input partitioning spec must remain untouched.
    self.assertCountEqual(spec.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(spec['c'].keys(), {'d', 'e'})
    # The returned spec drops 'c/d' but keeps the supported sibling 'c/e'.
    self.assertIsNotNone(out_spec)
    self.assertCountEqual(out_spec.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(out_spec['c'].keys(), {'e'})
102 |
103 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  absltest.main()
106 |
--------------------------------------------------------------------------------
/paxml/docs/tutorials/inputs_in_Pax-train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "4Rh0P4V34wZm"
7 | },
8 | "source": [
9 | "# Pax Workshop\n",
10 | "## Inputs in Pax - training\n",
11 | "\n",
12 | "This colab demonstrates how inputs in Pax work.\n"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {
19 | "id": "mNJyBaezlI7O"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "from praxis import base_input\n",
24 | "from praxis import base_hyperparams\n",
25 | "from paxml import seqio_input\n",
26 | "import numpy as np"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "sVtD2spKACvz"
33 | },
34 | "source": [
35 | "Let's start with a SeqIO input using the wsc training data."
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "id": "frocGIyv5Kd-"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "import t5.data.tasks\n",
47 | "p = seqio_input.SeqIOInput.HParams(\n",
48 | " mixture_name='super_glue_wsc_v102_simple_train',\n",
49 | " split_name='train',\n",
50 | " task_feature_lengths={'targets': 1280},\n",
51 | " feature_converter=seqio_input.LanguageModelFeatures(pack=True),\n",
52 | " is_training=True,\n",
53 | " use_cached=False,\n",
54 | " input_random_seed=123,\n",
55 | " batch_size=4)\n",
56 | "inp = base_hyperparams.instantiate(p)"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "id": "eKeVT2kelzoa"
64 | },
65 | "outputs": [],
66 | "source": [
67 | "# Get a batch, inspect the spec of the data\n",
68 | "batch = inp.get_next()\n",
69 | "for k, v in batch.FlattenItems():\n",
70 | " print(k, v.shape, v.dtype)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {
77 | "id": "dGzr0xtq7AYF"
78 | },
79 | "outputs": [],
80 | "source": [
81 | "# The data is packed\n",
82 | "for _ in range(4):\n",
83 | " batch = inp.get_next()\n",
84 | " print('segments: ', np.max(batch.segment_ids, axis=1))\n"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {
90 | "id": "Jq8trTrlDeXB"
91 | },
92 | "source": [
93 | "We set `input_random_seed=123` on the input hparams. What happens with `inp.reset()`? Does it reproduce the same data?\n",
94 | "\n",
95 | "What about if we re-instantiate the input object?"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {
102 | "id": "v9dhcdqj74lY"
103 | },
104 | "outputs": [],
105 | "source": [
106 | "# Tweak some fields\n",
107 | "p2 = p.clone().set(infeed_host_index=0, num_infeed_hosts=2, shuffle=False)\n",
108 | "# disable packing\n",
109 | "p2.feature_converter = seqio_input.LanguageModelFeatures(pack=False)\n",
110 | "inp2 = base_hyperparams.instantiate(p2)\n",
111 | "\n",
112 | "p2_complement = p2.clone().set(infeed_host_index=1)\n",
113 | "inp2_complement = base_hyperparams.instantiate(p2_complement)\n",
114 | "\n",
115 | "batches = [inp2.get_next(), inp2_complement.get_next()]"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "id": "CjPqIOz68lJm"
123 | },
124 | "outputs": [],
125 | "source": [
126 | "inp.ids_to_strings(batches[0].ids, [1280] * 4)"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {
132 | "id": "G-5jh4P4CcPp"
133 | },
134 | "source": [
135 | "Now inspect the data from `inp2_complement`. Verify that it does not overlap with the data from `inp2`. Does this hold if we run more batches?"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {
142 | "id": "8fbBkr64-2tA"
143 | },
144 | "outputs": [],
145 | "source": [
146 | "# The data is also no longer packed\n",
147 | "np.max(batches[0].segment_ids, axis=1)"
148 | ]
149 | }
150 | ],
151 | "metadata": {
152 | "kernelspec": {
153 | "display_name": "Python 3 (ipykernel)",
154 | "language": "python",
155 | "name": "python3"
156 | },
157 | "language_info": {
158 | "codemirror_mode": {
159 | "name": "ipython",
160 | "version": 3
161 | },
162 | "file_extension": ".py",
163 | "mimetype": "text/x-python",
164 | "name": "python",
165 | "nbconvert_exporter": "python",
166 | "pygments_lexer": "ipython3",
167 | "version": "3.8.10"
168 | }
169 | },
170 | "nbformat": 4,
171 | "nbformat_minor": 1
172 | }
173 |
--------------------------------------------------------------------------------
/paxml/docs/models.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | doc/pax/models
4 |
5 | [TOC]
6 |
7 | ## Introduction
8 |
9 | A **model** is a collection of layers defining the network.
10 |
11 | In Pax, Models are used in [Tasks][tasks], which are part of
12 | [Experiments][experiments] that can be *trained*, *evaluated*, and *decoded* (as
13 | well as a mix of these).
14 |
> Tip: For a rudimentary introduction to the basic Pax components, check out
16 | > [Pax Elements][pax-elements]. If you want to dive in for a hands-on
17 | > experience, try the [Pax Model and Task Jupyter Notebook][model_ipynb].
18 |
19 | ## Model How-To's
20 |
21 | ### Define a Model
22 |
23 | In Pax, a Model inherits from `BaseModel`; and `BaseModel`, in turn, inherits
24 | from `BaseLayer` along with a few interfaces for interacting with the
25 | model:
26 |
27 | * `compute_predictions()`
28 | * `compute_loss()`
29 | * `decode()`
30 | * `process_decode_out()`
31 |
A BaseModel is nothing more than a BaseLayer with a specific API that is used
33 | to integrate with the Pax trainer:
34 |
35 | ```python
36 | class BaseModel(base_layer.BaseLayer):
37 | ...
38 | ```
39 |
40 | Therefore, to build your own Pax Model, you will need to define these methods in
41 | your derived class.
42 |
43 | #### `compute_predictions`
44 |
45 | *Computes predictions for `input_batch`.*
46 |
47 | The output can be in the form of probabilistic distributions, such as softmax
48 | logits for discrete outputs, mixture of logistics for continuous values, or
49 | regression values.
50 |
51 | For training or evaluation, the output will be used for computing loss and
52 | gradient updates, including comparing predicted distributions between teacher
53 | and student for distillation. During inference the output can be used to compute
54 | final outputs, perhaps with sampling.
55 |
56 | Args:
57 |
58 | * input_batch: A `.NestedMap` object containing input tensors.
59 |
60 | Returns:
61 |
62 | * Predictions, either a single Tensor, a `.NestedMap`, or a namedtuple.
63 |
64 |
65 | #### `compute_loss`
66 |
67 | *Computes the loss and other metrics for the given predictions.*
68 |
69 | Args:
70 |
71 | * predictions: The output of `compute_predictions`.
72 | * input_batch: A `.NestedMap` object containing input tensors to this tower.
73 |
74 | Returns:
75 |
76 | * WeightedScalars - A dict or NestedMap containing str keys and
77 | (value, weight) pairs as values, where one or more entries are
78 | expected to correspond to the loss (or losses).
79 | * A dict containing arbitrary tensors describing something about each
80 | training example, where the first dimension of each tensor is the batch
81 | index.
82 |
83 | #### `decode`
84 |
85 | *Decodes the input_batch.*
86 |
87 | This code should be expected to run on TPUs.
88 |
89 | Args:
90 |
91 | * input_batch: The input batch. A `NestedMap` of tensors. Or, if input batch
92 | splitting is used, a list of `NestedMap`, one for each split.
93 |
94 | Returns a 3-tuple with:
95 |
96 | * weighted scalars, a NestedMap containing str keys and (value, weight)
97 | pairs for the current batch (a tuple of two scalars).
98 | * results, a `.NestedMap` as decoder output.
99 | * metrics, a NestedMap containing str keys and clu_metrics.Metric
100 | objects.
101 |
102 | #### `process_decode_out`
103 |
104 | *Processes one batch of decoded outputs.*
105 |
106 | This code will run on the host (CPU) and not on an accelerator (GPU or
107 | TPU). This allows you to run things that can't be processed on TPUs, such as
108 | strings.
109 |
110 | Args:
111 |
112 | * input_obj: The input object where a tokenizer is accessible.
113 | * decode_out: The output from decode(). May have an extra leading axis.
114 |
115 | Returns a 3-tuple with:
116 |
117 | * weighted scalars, a NestedMap containing str keys and (value, weight)
118 | pairs for the current batch (a tuple of two scalars).
119 | * A list of tuples where each element corresponds to a row in the batch.
120 | Each tuple is a key value pair.
121 | * metrics, a NestedMap containing str keys and clu_metrics.Metric
122 | objects. These will run outside of pmap/pjit.
123 |
124 | ---
125 |
126 | ### Select an Existing Model
127 |
128 | TODO: To be written. Volunteers welcome.
129 |
130 | For a library of pre-defined models, check out [base models][base-models], which
131 | includes:
132 |
133 | * LanguageModel
134 | * SequenceModel
135 | * ClassificationModel
136 |
137 |
138 |
139 |
140 | [base-models]: https://github.com/google/praxis/tree/main/praxis/layers/models.py
141 | [experiments]: https://github.com/google/paxml/tree/main/paxml/docs/experiments.md
142 | [model_ipynb]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md#pax-model-and-task
143 | [pax-elements]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md#pax-elements
144 | [tasks]: https://github.com/google/paxml/tree/main/paxml/docs/tasks.md
145 |
--------------------------------------------------------------------------------
/paxml/docs/hands-on-tutorials.md:
--------------------------------------------------------------------------------
1 | # Hands-on Tutorials
2 |
3 | doc/pax/hands-on-tutorials
4 |
5 | [TOC]
6 |
7 | TLDR: A curated set of Jupyter Notebooks to get you comfortable using Pax.
8 |
9 | ## Overview
10 |
The following Jupyter Notebooks have been created by Pax SMEs. The goal is to provide a hands-on
introduction to basic Pax activities.
13 |
14 | Good luck! Have fun!
15 |
16 | ## Jupyter Notebooks
17 |
18 | ### A JAX Primer
19 |
20 | In this first Jupyter Notebook, you will be introduced to JAX basics. When finished, you
21 | will be ready to dive in and see how JAX is used in the Pax infrastructure for
22 | training and multipod models.
23 |
24 | **Jupyter Notebook (coming soon):** [A JAX Primer][ipynb_jax_primer]
25 |
26 | ### JAX for Edit-Distance
27 |
28 | This lab showcases a few new things you can do in Jupyter Notebook with JAX. Specifically,
29 | you will develop a native, JAX-based edit-distance algorithm that works on
30 | padded batches.
31 |
32 | You are encouraged to build your own JAX Jupyter Notebooks as a way to learn by doing.
33 |
34 | **Jupyter Notebook (coming soon):** [JAX for Edit-Distance][ipynb_jax_ed]
35 |
36 | ### Pax Layer Basics
37 |
38 | Get your first hands-on look at Pax. You will learn about its fundamental
39 | component, the Pax Layers: the essential building blocks of models. In the
40 | process, you will learn about the basics for authoring a new Pax layer.
41 |
42 | Additionally, you will learn about Flax, a high-performance neural network
43 | library for JAX that is designed for flexibility.
44 |
45 | **Jupyter Notebook:** [Pax Layer Basics][ipynb_pax_layer]
46 |
47 | ### Inputs in Pax
48 |
49 | This short Jupyter Notebook demonstrates how inputs work in Pax.
50 |
51 | **Jupyter Notebook:**
52 | * [Inputs in Pax (training)][ipynb_pax_inputs_train]
53 | * [Inputs in Pax (eval)][ipynb_pax_inputs_eval]
54 |
55 | ### Pax Model and Task
56 |
57 | Model and task examples in Pax
58 |
59 | **Jupyter Notebook (coming soon):** [Pax Model and Task][ipynb_model_and_task]
60 |
61 | ### Pax End-to-end Tutorial
62 |
63 | Here you will put together what you have learned so far. You will be building a
64 | Translator using a Transformer Encoder/Decoder Architecture via the Pax
65 | framework.
66 |
67 | Without going too deep into the Layer design, you will see how to build, test,
68 | and combine them to train a model.
69 |
70 | **Jupyter Notebook:** [Pax End-to-end Tutorial][ipynb_pax_e2e]
71 |
72 | ### Pax Checkpointing
73 |
74 | This lab shows how to use checkpoints for warm-starting. It covers three topics
75 | around checkpoint handling:
76 |
77 | * **Run a fine-tuning** starting from a pretrained checkpoint in Pax.
78 | * **Inspect variables** in a checkpoint file.
79 | * **Test a model** interactively using a checkpoint file.
80 |
81 | **Jupyter Notebook (coming soon):** [Pax Checkpointing][ipynb_checkpoint]
82 |
83 | ### Pax RNN Decode
84 |
This Jupyter Notebook demonstrates how to set up *extend step*, *init*, and *update decode
86 | state cache* for the model's autoregressive decoding.
87 |
88 | In *autoregressive decoding*, each output of the network is generated based on
previously generated output. Models such as RNNs and transformers use
autoregressive decoding to generate new sequences.
91 |
92 | **Jupyter Notebook (coming soon):** [Pax RNN Decode][ipynb_rnn_decode]
93 |
94 | ### Sharding in Pax
95 |
96 | This Jupyter Notebook demonstrates how to use *sharding* in Pax.
97 |
98 | **Jupyter Notebook (coming soon):** [Sharding Jupyter Notebook][ipynb_shard_hard]
99 |
100 |
101 | ## Where to go from here
102 |
103 | Congratulations! At this point, you should have a pretty good idea on how to
104 | work with the various aspects of Pax. The rest of this site should help guide
105 | you as you move on to deeper topics.
106 |
107 |
108 |
109 |
110 | [ipynb_shard_hard]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/sharding.ipynb
111 | [ipynb_checkpoint]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax201_checkpointing.ipynb
112 | [ipynb_jax_ed]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_jax_for_edit_distance.ipynb
113 | [ipynb_jax_primer]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_jax_primer.ipynb
114 | [ipynb_model_and_task]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_model_and_task.ipynb
115 | [ipynb_pax_e2e]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_e2e_tutorial.ipynb
116 | [ipynb_pax_inputs_train]: https://github.com/google/paxml/blob/main/paxml/docs/tutorials/inputs_in_Pax-train.ipynb
117 | [ipynb_pax_inputs_eval]: https://github.com/google/paxml/blob/main/paxml/docs/tutorials/inputs_in_Pax-eval.ipynb
118 | [ipynb_pax_layer]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax_layer_basics.ipynb
119 | [ipynb_rnn_decode]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax_rnn_decode.ipynb
120 |
--------------------------------------------------------------------------------
/paxml/host_callback.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for host callbacks."""
17 |
18 | import collections
19 | import threading
20 |
21 |
class Repository:
  """Thread-safe container of ID-keyed strings to pass strings to host_callback.

  The intended usage is as follows:
  1. Set up a `Repository` that is accessible from both pre-processing
     and the model. For convenience, the function `repository(namespace)`
     in this module returns a per-namespace singleton `Repository`.
     Other approaches include a module-level `Repository` variable or a
     `Repository` injected via hparams.
  2. In pre-processing, use `repository(namespace).add()` to add strings
     and pass the resulting string IDs to the model.
  3. In the model/accelerator, use `repository(namespace).get()` to fetch
     strings by ID.
  4. Set the `device_index` argument in host_callback.call to match the device
     that runs pre-processing.
  5. In post-processing, use `repository(namespace).pop()` to remove strings
     by ID. There is also a last-resort eviction policy, see `MAX_SIZE`.

  To avoid OOM when the caller does not promptly pop() the strings they add(),
  there is a limit on size. If this grows beyond that limit, then strings are
  evicted in least-recently-added order.

  TODO(terrykoo): Define string ID using fingerprints to allow caching and
  reuse. If we do this, however, we will need to add refcounts so pop() only
  removes an ID when all usages of it have subsided.
  """

  # Maximum number of strings held in the repository of each namespace.
  MAX_SIZE = 10000

  def __init__(self, max_size: int = MAX_SIZE):
    """Creates an empty repository.

    If you use a non-singleton `Repository`, the generated string IDs might not
    be sufficiently unique.

    Args:
      max_size: Maximum number of strings to hold.
    """
    self._max_size = max_size
    # Guards all mutable state below; add()/pop()/get() may be called from
    # different threads (e.g., pre-processing vs. the host callback).
    self._lock = threading.Lock()
    # Mapping from assigned ID to string; insertion order matches ID order.
    self._string_by_id = dict()
    # IDs are assigned sequentially, and eviction walks the same sequence, so
    # the least-recently-added entries are evicted first.
    self._next_id_to_assign = 0
    self._next_id_to_evict = 0

  def add(self, value: str) -> int:
    """Adds new string to the mapping and returns its global ID.

    If necessary, also evicts old strings to keep this under the maximum size.

    Args:
      value: String to add.

    Returns:
      ID of the string. IDs are unique per LM server provided the caller uses
      the singleton.
    """
    with self._lock:
      string_id = self._next_id_to_assign
      self._next_id_to_assign += 1
      self._string_by_id[string_id] = value

      while len(self._string_by_id) > self._max_size:
        # Use a default: the next eviction candidate may already have been
        # removed via pop(), which previously raised KeyError here. In that
        # case the size does not shrink and the loop advances to the next
        # candidate.
        self._string_by_id.pop(self._next_id_to_evict, None)
        self._next_id_to_evict += 1

      return string_id

  def pop(self, value_id: int) -> bool:
    """Attempts to remove the string mapped to `value_id`.

    The string might not be removed if the `value_id` is unknown, e.g., if it
    was already removed or evicted.

    Args:
      value_id: ID of the string to remove, as returned by add().

    Returns:
      True if the string was removed.
    """
    with self._lock:
      return self._string_by_id.pop(value_id, None) is not None

  def get(self, value_id: int) -> str:
    """Returns the string mapped to `value_id`.

    Args:
      value_id: ID of the string to fetch, as returned by add().

    Returns:
      String associated with the `value_id`.

    Raises:
      KeyError: If the `value_id` is not mapped.
    """
    with self._lock:
      return self._string_by_id[value_id]

  @property
  def size(self) -> int:
    """Returns the number of strings in this repository."""
    with self._lock:
      return len(self._string_by_id)
124 |
125 |
# This is defined and instantiated after the class, because (unlike languages
# like C++, Java, or TypeScript) Python classes don't exist until after their
# definition ends.
# Guards lazy creation of the per-namespace singletons handed out by
# `repository()` below.
_global_lock = threading.Lock()
# Maps namespace -> singleton `Repository`; defaultdict creates an empty
# `Repository` the first time a namespace is looked up.
_global_repository_by_namespace = collections.defaultdict(Repository)
131 |
132 |
def repository(namespace: str) -> Repository:
  """Returns the singleton `Repository` for `namespace`, creating it if new."""
  with _global_lock:
    singleton = _global_repository_by_namespace[namespace]
  return singleton
136 |
--------------------------------------------------------------------------------