├── paxml ├── docs │ ├── images │ │ ├── 16B-loss.png │ │ ├── 16B-pplx.png │ │ ├── 1B-loss.png │ │ ├── 1B-pplx.png │ │ ├── GPT3-XL-loss.png │ │ ├── GPT3-XL-pplx.png │ │ └── Weak_scaling_of_large_language_model_training_on_TPU_v4.png │ ├── README.md │ ├── tasks.md │ ├── tutorials │ │ ├── inputs_in_Pax-eval.ipynb │ │ └── inputs_in_Pax-train.ipynb │ ├── models.md │ └── hands-on-tutorials.md ├── tasks │ ├── lm │ │ ├── testdata │ │ │ └── tfrecords │ │ ├── __init__.py │ │ ├── BUILD │ │ ├── params │ │ │ ├── optimal_scaling.py │ │ │ ├── c4_test.py │ │ │ └── BUILD │ │ └── input_generator_test.py │ ├── test │ │ ├── BUILD │ │ └── synthetic.py │ └── vision │ │ ├── BUILD │ │ ├── input_generator_test.py │ │ └── params │ │ └── BUILD ├── contrib │ └── gpu │ │ ├── README.md │ │ └── scripts_gpu │ │ ├── download_boolq.py │ │ ├── download_the_pile.py │ │ ├── download_lambada.py │ │ ├── checkpoint_utils.py │ │ ├── run_pile_singlenode.sh │ │ ├── run_lambada_singlenode.sh │ │ ├── run_base_config_multinode.sh │ │ ├── run_pile_multinode.sh │ │ └── lora_utils.py ├── AUTHORS ├── pip_package │ ├── cloudbuild-postsubmit.yaml │ ├── cloudbuild-presubmit.yaml │ ├── postsubmit.Dockerfile │ ├── compile_requirements_helper.sh │ ├── cloudbuild.yaml │ ├── presubmit.Dockerfile │ ├── cloudbuild-release.yaml │ ├── collect_wheels.sh │ ├── compile_requirements.sh │ ├── build_pip_pkg.sh │ ├── release.Dockerfile │ ├── build.sh │ ├── prepare_release.sh │ └── Dockerfile ├── first_result_metric_callback.py ├── test_helper.py ├── lineage_logging.py ├── tools │ ├── fiddle │ │ ├── graphviz_utils_test.py │ │ ├── wrap_nested_maps.py │ │ ├── convert_seqio_task_objects_test.py │ │ ├── codegen_pax_code_ir.py │ │ ├── convert_seqio_task_objects.py │ │ ├── wrap_nested_maps_test.py │ │ ├── codegen_highlevel_parameterization_test.py │ │ ├── graphviz_utils.py │ │ ├── unshare_sharding.py │ │ ├── remove_sharding.py │ │ ├── make_parameterized_experiment.py │ │ ├── codegen_highlevel_parameterization.py │ │ ├── codegen_tracer.py │ 
│ ├── codegen_external_init_checkpoint_fns_test.py │ │ └── remove_sharding_test.py │ ├── validate_config.py │ ├── dump_hparams.py │ ├── dump_input_specs.py │ ├── BUILD │ └── dump_input_specs_lib.py ├── experimental │ ├── nested_map_config_helper_test.py │ ├── nested_map_config_helper.py │ └── BUILD ├── base_task.py ├── paxml.bzl ├── ml_monitoring.py ├── checkpoint_types.py ├── experiment_imports_all_test.py ├── partitioning_test.py ├── base_executor.py ├── experiment_vars_summary_parser.py ├── checkpoint_version.py ├── main_test.py ├── experiment_vars_summary_test.py ├── host_callback_test.py ├── setup_jax.py ├── ghostnorm │ └── BUILD ├── base_inference_runner_test.py ├── profiling.py ├── base_inference_runner.py ├── xla_passthrough.py ├── train_states.py ├── xla_passthrough_test.py └── host_callback.py ├── requirements.in ├── CONTRIBUTING.md ├── WORKSPACE └── setup.py /paxml/docs/images/16B-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/16B-loss.png -------------------------------------------------------------------------------- /paxml/docs/images/16B-pplx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/16B-pplx.png -------------------------------------------------------------------------------- /paxml/docs/images/1B-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/1B-loss.png -------------------------------------------------------------------------------- /paxml/docs/images/1B-pplx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/1B-pplx.png -------------------------------------------------------------------------------- 
/paxml/docs/images/GPT3-XL-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/GPT3-XL-loss.png -------------------------------------------------------------------------------- /paxml/docs/images/GPT3-XL-pplx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/GPT3-XL-pplx.png -------------------------------------------------------------------------------- /paxml/tasks/lm/testdata/tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/tasks/lm/testdata/tfrecords -------------------------------------------------------------------------------- /paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/paxml/HEAD/paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png -------------------------------------------------------------------------------- /paxml/contrib/gpu/README.md: -------------------------------------------------------------------------------- 1 | # Pax on GPUs 2 | Please refer to [Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax), NVIDIA's project that enables seamless training of LLMs, CV models and multimodal models in JAX, for information about running experiments on GPUs in Pax. 3 | -------------------------------------------------------------------------------- /paxml/AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Pax's significant contributors. 
2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | NVIDIA Corporation -------------------------------------------------------------------------------- /paxml/pip_package/cloudbuild-postsubmit.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '--build-arg', 'image_name=${_IMAGE_NAME}', 6 | '-f', 'paxml/pip_package/postsubmit.Dockerfile', '.' 7 | ] 8 | 9 | substitutions: 10 | _PYTHON_VERSION: '3.10' 11 | _RELEASE_VERSION: 'nightly' # or rX.Y 12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 13 | options: 14 | dynamic_substitutions: true 15 | substitution_option: 'ALLOW_LOOSE' 16 | timeout: 1200s 17 | -------------------------------------------------------------------------------- /paxml/pip_package/cloudbuild-presubmit.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '--build-arg', 'image_name=${_IMAGE_NAME}', 6 | '-f', 'paxml/pip_package/presubmit.Dockerfile', '.' 
7 | ] 8 | 9 | substitutions: 10 | _PYTHON_VERSION: '3.10' 11 | _RELEASE_VERSION: 'nightly' # or rX.Y 12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 13 | options: 14 | dynamic_substitutions: true 15 | substitution_option: 'ALLOW_LOOSE' 16 | machineType: E2_HIGHCPU_8 17 | timeout: 1200s 18 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # To update requirements.txt for praxis and paxml, run: 2 | # bash ./compile_requirements.sh 3 | 4 | absl-py 5 | clu @ git+https://github.com/google/CommonLoopUtils 6 | etils 7 | flax @ git+https://github.com/google/flax 8 | graphviz 9 | jax @ git+https://github.com/google/jax 10 | lingvo 11 | numpy 12 | orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint 13 | praxis 14 | protobuf==3.19.6 15 | pyglove 16 | seqio-nightly 17 | t5 18 | tensorflow~=2.9.2 19 | tensorflow-text~=2.9.0 20 | tensorstore 21 | tensorflow-datasets==4.8.3 22 | tfds-nightly==4.8.3.dev202303280045 23 | tensorflow-metadata==1.12.0 24 | -------------------------------------------------------------------------------- /paxml/tasks/lm/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | -------------------------------------------------------------------------------- /paxml/first_result_metric_callback.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | def train_first_result_callback_fn(metric: bool) -> None: 17 | return 18 | -------------------------------------------------------------------------------- /paxml/contrib/gpu/scripts_gpu/download_boolq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import tensorflow_datasets as tfds 17 | 18 | # This will download 'super_glue/boolq' to TFDS_DATA_DIR (environment variable). 
19 | ds = tfds.load('super_glue/boolq') 20 | -------------------------------------------------------------------------------- /paxml/contrib/gpu/scripts_gpu/download_the_pile.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import paxml.contrib.gpu.scripts_gpu.tfds_pile 17 | import tensorflow_datasets as tfds 18 | 19 | # This will download 'ThePile' to TFDS_DATA_DIR (environment variable). 20 | ds = tfds.load('ThePile') 21 | -------------------------------------------------------------------------------- /paxml/test_helper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Helpers for unit tests.""" 17 | 18 | import os 19 | 20 | 21 | def test_src_dir_path(relative_path: str) -> str: 22 | return os.path.join(os.environ['TEST_SRCDIR'], '__main__/paxml', 23 | relative_path) 24 | -------------------------------------------------------------------------------- /paxml/contrib/gpu/scripts_gpu/download_lambada.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import paxml.contrib.gpu.scripts_gpu.tfds_lambada 17 | import tensorflow_datasets as tfds 18 | 19 | # This will download 'MyLambada' to TFDS_DATA_DIR (environment variable). 20 | ds = tfds.load('MyLambada') 21 | -------------------------------------------------------------------------------- /paxml/lineage_logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils for lineage logging.""" 17 | 18 | # internal lineage logging module import 19 | 20 | 21 | def log_config_file_lineage(filepath: str) -> None: 22 | """Logs config file lineage.""" 23 | del filepath 24 | # Internal lineage logging implementation 25 | -------------------------------------------------------------------------------- /paxml/pip_package/postsubmit.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG image_name 2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest" 3 | FROM $base_image 4 | 5 | RUN rm -rf /praxis && rm -rf /paxml/paxml && rm -rf /paxml/praxis 6 | COPY . /paxml_new 7 | RUN git clone https://github.com/google/praxis.git 8 | RUN mv /praxis/praxis /paxml/ && mv /paxml_new/paxml /paxml/ 9 | RUN pip3 uninstall -y fiddle 10 | RUN pip3 uninstall -y seqio 11 | RUN pip3 uninstall -y flax 12 | RUN pip3 uninstall -y jax 13 | RUN pip3 install --no-deps -r /paxml/paxml/pip_package/requirements.txt 14 | 15 | RUN cd /paxml && bazel build ... 16 | RUN cd /paxml && \ 17 | bazel test \ 18 | --test_output=all \ 19 | --test_verbose_timeout_warnings \ 20 | -- \ 21 | paxml/... 
\ 22 | -paxml/tasks/lm/params:c4_test \ 23 | -paxml/tasks/vision:input_generator_test \ 24 | -paxml:checkpoint_managers_test \ 25 | -paxml:seqio_input_test \ 26 | -paxml:tasks_lib_test 27 | 28 | WORKDIR / 29 | 30 | CMD ["/bin/bash"] 31 | -------------------------------------------------------------------------------- /paxml/pip_package/compile_requirements_helper.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | #!/bin/bash 17 | 18 | set -e -x 19 | 20 | pip3 install -U pip-tools 21 | cd /tmp/requirements 22 | pip-compile --quiet --output-file paxml-requirements.txt praxis-requirements.in paxml-requirements.in 23 | pip-compile --quiet --output-file praxis-requirements.txt praxis-requirements.in 24 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/graphviz_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for graphviz.""" 17 | 18 | from absl.testing import absltest 19 | from paxml.tools.fiddle import graphviz_utils 20 | from paxml.tools.fiddle import test_fixtures 21 | 22 | 23 | class GraphvizTest(absltest.TestCase): 24 | 25 | def test_smoke_render(self): 26 | config = test_fixtures.SampleExperimentNewBaseline().experiment_fixture() 27 | graphviz_utils.render(config=config) 28 | 29 | 30 | if __name__ == "__main__": 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /paxml/tools/validate_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""A simple utility to validate an experiment config. 17 | 18 | The binary target `:validate_config` is defined by `pax_targets()` in the `BUILD` 19 | file. 
20 | 21 | Example commandline: 22 | bazel run //PATH/TO/PAX/TARGETS:validate_config -- \ 23 | --exp=lm1b.Lm1bTransformerL32H8kSPMD8x8Repeat \ 24 | --completeness=light 25 | """ 26 | 27 | from paxml.tools import validate_config_lib 28 | 29 | 30 | if __name__ == '__main__': 31 | validate_config_lib.main() 32 | 33 | -------------------------------------------------------------------------------- /paxml/pip_package/cloudbuild.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}', 6 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}', 7 | '-f', 'paxml/pip_package/Dockerfile', '.' 8 | ] 9 | timeout: 3600s 10 | - name: 'gcr.io/cloud-builders/docker' 11 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'] 12 | timeout: 1800s 13 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}' 14 | entrypoint: 'bash' 15 | args: ['-c', 'source paxml/pip_package/collect_wheels.sh && collect_wheels ${_RELEASE_VERSION} ${_WHEEL_FOLDER}'] 16 | 17 | substitutions: 18 | _PYTHON_VERSION: '3.10' 19 | _RELEASE_VERSION: 'nightly' # or rX.Y 20 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 21 | _WHEEL_FOLDER: '/tmp/wheels' 22 | options: 23 | dynamic_substitutions: true 24 | substitution_option: 'ALLOW_LOOSE' 25 | machineType: E2_HIGHCPU_32 26 | timeout: 5400s 27 | artifacts: 28 | objects: 29 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)' 30 | paths: ['/**/*.whl', '/**/*.txt'] 31 | -------------------------------------------------------------------------------- /paxml/pip_package/presubmit.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG image_name 2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest" 3 | FROM $base_image 4 | 5 | RUN rm -rf /praxis && rm -rf /paxml/paxml && rm -rf /paxml/praxis 6 | COPY . 
/paxml_new 7 | RUN git clone https://github.com/google/praxis.git 8 | RUN mv /praxis/praxis /paxml/ && mv /paxml_new/paxml /paxml/ 9 | RUN pip3 uninstall -y fiddle 10 | RUN pip3 uninstall -y seqio 11 | RUN pip3 uninstall -y flax 12 | RUN pip3 uninstall -y jax 13 | RUN pip3 install --no-deps -r /paxml/paxml/pip_package/requirements.txt 14 | RUN cd /paxml && bazel build ... 15 | 16 | # RUN cd /paxml && bazel test paxml/... --test_output=all --test_verbose_timeout_warnings 17 | # RUN cd /paxml && bazel test paxml:tasks_lib_test --test_output=all --test_verbose_timeout_warnings 18 | RUN cd /paxml && \ 19 | bazel test \ 20 | --test_output=all \ 21 | --test_verbose_timeout_warnings \ 22 | -- \ 23 | paxml/... \ 24 | -paxml/tasks/lm/params:c4_test \ 25 | -paxml/tasks/vision:input_generator_test \ 26 | -paxml:checkpoint_managers_test \ 27 | -paxml:seqio_input_test \ 28 | -paxml:tasks_lib_test 29 | WORKDIR / 30 | 31 | CMD ["/bin/bash"] 32 | -------------------------------------------------------------------------------- /paxml/experimental/nested_map_config_helper_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for nested_map_config_helper.""" 17 | 18 | from absl.testing import absltest 19 | from paxml.experimental import nested_map_config_helper 20 | 21 | 22 | class NestedMapConfigHelperTest(absltest.TestCase): 23 | 24 | def test_make_nested_map(self): 25 | result = nested_map_config_helper.make_nested_map(base={"foo": 1, "bar": 2}) 26 | self.assertEqual(result.foo, 1) 27 | self.assertEqual(result.bar, 2) 28 | 29 | 30 | if __name__ == "__main__": 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 
29 | -------------------------------------------------------------------------------- /paxml/pip_package/cloudbuild-release.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}', 6 | '-f', 'paxml/pip_package/release.Dockerfile', '.', 7 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}', 8 | '--build-arg', 'praxis_version=${_PRAXIS_VERSION}', 9 | ] 10 | timeout: 3600s 11 | - name: 'gcr.io/cloud-builders/docker' 12 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'] 13 | timeout: 1800s 14 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}' 15 | entrypoint: 'bash' 16 | args: ['-c', 'mv ${_WHEEL_FOLDER}/*.whl .'] 17 | 18 | substitutions: 19 | _PYTHON_VERSION: '3.10' 20 | _RELEASE_VERSION: '1.4.0' # or rX.Y 21 | _PRAXIS_VERSION: '1.4.0' # or rX.Y 22 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 23 | _WHEEL_FOLDER: '/tmp/wheels' 24 | options: 25 | dynamic_substitutions: true 26 | substitution_option: 'ALLOW_LOOSE' 27 | machineType: E2_HIGHCPU_8 28 | timeout: 5400s 29 | artifacts: 30 | objects: 31 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)-paxml-${_RELEASE_VERSION}' 32 | paths: ['/**/*.whl'] 33 | -------------------------------------------------------------------------------- /paxml/tasks/test/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Test model configurations. 18 | 19 | load("@rules_python//python:defs.bzl", "py_library") 20 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 21 | 22 | package(default_visibility = JAX_VISIBILITY) 23 | 24 | licenses(["notice"]) 25 | 26 | py_library( 27 | name = "synthetic", 28 | srcs = [ 29 | "synthetic.py", 30 | ], 31 | tags = ["keep_dep"], 32 | deps = [ 33 | "//paxml:base_experiment", 34 | "//paxml:experiment_registry", 35 | "//praxis:pax_fiddle", 36 | "//praxis/layers", 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /paxml/base_task.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base class for all tasks. 
17 | 18 | A model solely consists of the network, while a task combines one or several 19 | models with one or several learners/optimizers. 20 | """ 21 | 22 | from __future__ import annotations 23 | 24 | import abc 25 | 26 | from praxis import base_hyperparams 27 | 28 | 29 | class BaseTask( 30 | base_hyperparams.FiddleBaseParameterizable, metaclass=abc.ABCMeta 31 | ): 32 | """Abstract base class for all tasks.""" 33 | 34 | def __post_init__(self): 35 | assert self.name, ( 36 | 'Task params for %s must have a "name"' % self.__class__.__name__ 37 | ) 38 | -------------------------------------------------------------------------------- /paxml/tasks/test/synthetic.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test model configuration using synthetic data.""" 17 | 18 | from paxml import base_experiment 19 | from paxml import experiment_registry 20 | from praxis import layers 21 | from praxis import pax_fiddle 22 | 23 | 24 | @experiment_registry.register 25 | class SyntheticClassifier(base_experiment.BaseExperiment): 26 | # TODO(shafey): Implement a real test model. 
27 | 28 | def datasets(self): 29 | return [] 30 | 31 | def task(self): 32 | act_p = pax_fiddle.Config(layers.Identity) 33 | return act_p 34 | 35 | 36 | @experiment_registry.register 37 | class SharedNameExperiment(SyntheticClassifier): 38 | pass 39 | -------------------------------------------------------------------------------- /paxml/pip_package/collect_wheels.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | #!/bin/bash 17 | 18 | function collect_wheels() { 19 | release_version=$1 20 | wheel_version="${release_version}" 21 | wheel_folder=$2 22 | if [ "${release_version}" != "nightly" ]; then 23 | wheel_version=$( echo "${release_version}" | grep -oP '\d+.\d+(.\d+)?' ) 24 | fi 25 | 26 | mkdir /tmp/staging-wheels 27 | pushd /tmp/staging-wheels 28 | cp $wheel_folder/*.whl . 29 | rename -v "s/^paxml-(.*?)-py3/paxml-${wheel_version}+$(date -u +%Y%m%d)-py3/" *.whl 30 | rename -v "s/^praxis-(.*?)-py3/praxis-${wheel_version}+$(date -u +%Y%m%d)-py3/" *.whl 31 | popd 32 | mv /tmp/staging-wheels/* . 33 | mv $wheel_folder/*.txt . 34 | } 35 | -------------------------------------------------------------------------------- /paxml/paxml.bzl: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 
# Placeholder to use until bazel supports pytype_*.
def pytype_library(name, **kwargs):
    """Forwards to py_library; pytype checking is not available in OSS bazel."""
    native.py_library(name = name, **kwargs)

def pytype_strict_library(name, **kwargs):
    """Forwards to py_library; the "strict" variant has no OSS equivalent."""
    native.py_library(name = name, **kwargs)

def pytype_binary(name, **kwargs):
    """Forwards to py_binary; pytype checking is not available in OSS bazel."""
    native.py_binary(name = name, **kwargs)

def pytype_strict_binary(name, **kwargs):
    """Forwards to py_binary; the "strict" variant has no OSS equivalent."""
    native.py_binary(name = name, **kwargs)

def pytype_strict_test(name, **kwargs):
    """Forwards to py_test; the "strict" variant has no OSS equivalent."""
    native.py_test(name = name, **kwargs)

# Placeholder to use until bazel supports py_strict_test.
def py_strict_test(name, **kwargs):
    """Forwards to py_test; strict-deps checking has no OSS equivalent."""
    native.py_test(name = name, **kwargs)
def make_nested_map(base: dict[str, Any]) -> pytypes.NestedMap:
  """Builds a `NestedMap` from a plain dict.

  Plain dicts are safe to embed in Fiddle configs; this helper performs the
  dict -> NestedMap conversion at build time so the config itself never has
  to hold a custom object.

  Args:
    base: Plain dict whose contents should back the NestedMap.

  Returns:
    A `pytypes.NestedMap` wrapping the entries of `base`.
  """
  return pytypes.NestedMap(base)
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""A simple utility to dump experiment hparams to stdout or a txt file. 17 | 18 | The binary target `:dump_hparams` is defined by `pax_targets()` in the `BUILD` 19 | file. 20 | 21 | Example commandline: 22 | python paxml/tools/dump_hparams.py \ 23 | --exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamTest \ 24 | --params_ofile=/tmp/bert.txt 25 | 26 | Alternatively, omitting --params_ofile just prints to stdout, which might be 27 | useful in shell pipelines. 28 | 29 | You may also additionally specify --post_init_params_ofile, to write a second 30 | file consisting of the post-init model params (this takes longer to generate): 31 | --post_init_params_ofile=/tmp/lm_post.txt 32 | """ 33 | 34 | from paxml.tools import dump_hparams_lib 35 | 36 | 37 | if __name__ == '__main__': 38 | dump_hparams_lib.main() 39 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/wrap_nested_maps.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Wraps NestedMap instances into config sub-graphs that produce them. 
def wrap_nested_maps(config):
  """Replaces each NestedMap in `config` with a config sub-graph producing it.

  NestedMap is a custom type that Fiddle tooling cannot introspect well, so
  every occurrence is rewritten into a `pax_fiddle.Config` of
  `nested_map_config_helper.make_nested_map` over an equivalent plain dict.
  """

  def _rewrite(node, state: daglish.State):
    if isinstance(node, pytypes.NestedMap):
      # Downgrade to a plain dict wrapped in a build-time conversion config.
      node = pax_fiddle.Config(
          nested_map_config_helper.make_nested_map, base=dict(node)
      )
    return state.map_children(node)

  # Memoized traversal preserves sharing (aliased sub-objects stay aliased).
  return daglish.MemoizedTraversal.run(_rewrite, config)
15 | 16 | """ML Monitoring for PAX.""" 17 | 18 | import contextlib 19 | import enum 20 | 21 | 22 | class MlEvent(enum.Enum): 23 | """ML events to be recorded.""" 24 | 25 | INITIALIZE_BACKEND = enum.auto() 26 | INITIALIZE_SETUP = enum.auto() 27 | MAIN_LOOP = enum.auto() 28 | TRAIN_STEP = enum.auto() 29 | EVAL_STEP = enum.auto() 30 | DECODE_STEP = enum.auto() 31 | 32 | 33 | class EventBoundary(enum.Enum): 34 | """Event boundary to be recorded.""" 35 | 36 | START = enum.auto() 37 | END = enum.auto() 38 | 39 | 40 | def record_step_number(step_number: int): 41 | """Records the step number.""" 42 | pass 43 | 44 | 45 | def record_event_boundary(event: MlEvent, boundary: EventBoundary, **kwargs): 46 | """Records the event boundary.""" 47 | pass 48 | 49 | 50 | @contextlib.contextmanager 51 | def ml_event_logger(event: MlEvent, **kwargs): 52 | yield 53 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/convert_seqio_task_objects_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for convert_seqio_task_objects.""" 17 | 18 | from absl.testing import absltest 19 | import fiddle as fdl 20 | from paxml.tools.fiddle import convert_seqio_task_objects 21 | from praxis import pax_fiddle 22 | import seqio 23 | from seqio import test_utils 24 | 25 | 26 | class ConvertSeqioTaskObjectsTest(test_utils.FakeTaskTest): 27 | 28 | def test_convert_seqio_tasks(self): 29 | config = {"task": seqio.get_mixture_or_task("tfds_task")} 30 | transformed = convert_seqio_task_objects.convert_seqio_tasks(config=config) 31 | self.assertIsInstance(transformed["task"], pax_fiddle.Config) 32 | self.assertEqual( 33 | fdl.ordered_arguments(transformed["task"]), # pytype: disable=wrong-arg-types 34 | {"task_or_mixture_name": "tfds_task"}, 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | absltest.main() 40 | -------------------------------------------------------------------------------- /paxml/docs/README.md: -------------------------------------------------------------------------------- 1 | # Pax Basics (Start here) 2 | 3 | The **Pax Basics (Start here)** section is designed to provide all the of basics 4 | concepts necessary to understand the rest of the site. 5 | 6 | | **Section** | **Description** | 7 | | -------------------------------- | ----------------------------------------- | 8 | | [About Pax][about-pax] | A gentle introduction to what Pax is, why its important, and its major components. | 9 | | [Learning Pax][learning-pax] | An introduction to the Pax Components along with resources for learning more. | 10 | | [Key Concepts][concepts] | A set of concepts Pax users will frequently encounter. | 11 | | [Tutorials][tutorials] | Notebook examples for hands-on introduction. | 12 | | [Life of a Pax Experiment][life] (coming soon)| The main concepts involved in defining and running machine learning experiments using Pax. 
| 13 | 14 | Armed with the Pax concepts presented in these pages, you should be able to 15 | grasp all of the user journeys and how-to's presented in this site. 16 | 17 | 18 | 19 | 20 | [about-pax]: https://github.com/google/paxml/tree/main/paxml/docs/about-pax.md 21 | [life]: https://github.com/google/paxml/tree/main/paxml/docs/life_of_an_experiment.md 22 | [concepts]: https://github.com/google/paxml/tree/main/paxml/docs/concepts.md 23 | [learning-pax]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md 24 | [tutorials]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md -------------------------------------------------------------------------------- /paxml/contrib/gpu/scripts_gpu/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class CheckpointRestoreMixin(ABC):
  """Mixin that wires a warm-start checkpoint into a task config."""

  # Path of the checkpoint to restore from; None disables restoration.
  CHECKPOINT_RESTORE_PATH = None
  # Optional ignore rules forwarded to CheckpointLoadingRules.
  CHECKPOINT_IGNORE_RULES = None

  def configure_checkpoint_restore(
      self, task_p: pax_fiddle.Config[tasks_lib.SingleTask]
  ) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Installs init_from_checkpoint_rules on `task_p` if a path is set."""
    if not self.CHECKPOINT_RESTORE_PATH:
      return task_p

    loading_rules = tasks_lib.CheckpointLoadingRules(
        task_p=task_p.clone(),
        # Map every variable name through unchanged.
        load_rules=[("(.*)", "{}")],
        ignore_rules=self.CHECKPOINT_IGNORE_RULES,
        input_specs_provider_p=self.get_input_specs_provider_params(),
    )
    task_p.train.init_from_checkpoint_rules = {
        self.CHECKPOINT_RESTORE_PATH: loading_rules
    }
    return task_p
28 | 29 | Attributes: 30 | highlevel_accesses: Accesses to high-level settings. These become fields on 31 | the generated class. 32 | sharding_diff_module: CST module containing a fiddler that will re-add 33 | sharding to a model. Factoring our fixtures into ones that generate an 34 | unsharded module, and a function that re-adds the sharding can be more 35 | readable. (You can disable this in `codegen.py`.) 36 | """ 37 | 38 | highlevel_accesses: dict[str, Any] = dataclasses.field(default_factory=dict) 39 | sharding_diff_module: cst.Module | None = None 40 | -------------------------------------------------------------------------------- /paxml/tasks/vision/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Vision modeling-specific libraries and model configurations 18 | 19 | load("//paxml:paxml.bzl", "pytype_strict_library") 20 | load("//paxml:paxml.bzl", "py_strict_test") 21 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 22 | 23 | licenses(["notice"]) 24 | 25 | package(default_visibility = JAX_VISIBILITY) 26 | 27 | pytype_strict_library( 28 | name = "input_generator", 29 | srcs = [ 30 | "input_generator.py", 31 | "resnet_preprocessing.py", 32 | ], 33 | deps = [ 34 | # Implicit absl.logging dependency. 
35 | # Implicit tensorflow_no_contrib dependency. 36 | ], 37 | ) 38 | 39 | py_strict_test( 40 | name = "input_generator_test", 41 | srcs = ["input_generator_test.py"], 42 | tags = [ 43 | "external", 44 | "notap", 45 | "requires-net:external", 46 | ], 47 | deps = [ 48 | ":input_generator", 49 | # Implicit absl.testing.absltest.absltest dependency. 50 | ], 51 | ) 52 | -------------------------------------------------------------------------------- /paxml/checkpoint_types.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module defining possible checkpoint types and utility methods.""" 17 | 18 | import enum 19 | from etils import epath 20 | from paxml import base_task 21 | from praxis import pax_fiddle 22 | from praxis import py_utils 23 | 24 | 25 | @enum.unique 26 | class CheckpointType(str, enum.Enum): 27 | """The type of the checkpointing format.""" 28 | 29 | UNSPECIFIED = 'unspecified' 30 | FLAX = 'flax' 31 | GDA = 'gda' 32 | PERSISTENCE = 'persistence' 33 | 34 | 35 | def retrieve_checkpoint_type( 36 | maybe_use_persistence_checkpointing, 37 | task: base_task.BaseTask | pax_fiddle.Config[base_task.BaseTask], 38 | ) -> CheckpointType: 39 | """Retrieves the CheckpointType given the input arguments.""" 40 | using_pjit = task.model.mesh_shape is not None # pytype: disable=attribute-error 41 | if using_pjit or py_utils.pmap_use_tensorstore(): 42 | if maybe_use_persistence_checkpointing: 43 | return CheckpointType.PERSISTENCE 44 | else: 45 | return CheckpointType.GDA 46 | else: 47 | # pmap uses FLAX, Persistence-based or not. 48 | return CheckpointType.FLAX 49 | -------------------------------------------------------------------------------- /paxml/pip_package/compile_requirements.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
#!/bin/bash
# This script generates new requirements.txt files for Praxis and Paxml.
# It pulls the nightly build docker image and re-compiles requirements.in.
# NOTE(review): the shebang above only takes effect when it is the first line
# of the file; a license header precedes it here, so invoke via `bash`.

set -e -x

export TMP_FOLDER="$HOME/tmp/requirements"

# Use -d: TMP_FOLDER is a directory, so the original `[ -f ... ]` test never
# matched and a stale directory from a previous run was never cleaned up.
[ -d "$TMP_FOLDER" ] && rm -rf "$TMP_FOLDER"
mkdir -p "$TMP_FOLDER"
cp ../../paxml/pip_package/requirements.in "$TMP_FOLDER/paxml-requirements.in"
cp ../../praxis/pip_package/requirements.in "$TMP_FOLDER/praxis-requirements.in"
cp ./compile_requirements_helper.sh "$TMP_FOLDER/"
# Comment out the praxis dependency so paxml's pins are resolved independently.
sed -i 's/praxis/#praxis/' "$TMP_FOLDER/paxml-requirements.in"

docker pull gcr.io/pax-on-cloud-project/paxml_nightly_3.10:latest
docker run --rm -a stdin -a stdout -a stderr -v "$TMP_FOLDER":/tmp/requirements \
  --name container1 gcr.io/pax-on-cloud-project/paxml_nightly_3.10:latest \
  bash /tmp/requirements/compile_requirements_helper.sh

cp "$TMP_FOLDER/paxml-requirements.txt" ../../paxml/pip_package/requirements.txt
cp "$TMP_FOLDER/praxis-requirements.txt" ../../praxis/pip_package/requirements.txt

rm -rf "$TMP_FOLDER"
15 | 16 | """Dumps the input specs of an experiment config.""" 17 | 18 | from collections.abc import Sequence 19 | 20 | from absl import app 21 | from absl import flags 22 | from paxml import experiment_registry 23 | from paxml.tools import dump_input_specs_lib 24 | import tensorflow.compat.v2 as tf 25 | 26 | _EXP = flags.DEFINE_string('exp', None, 'A registered experiment name.') 27 | _OUTPUT_FILENAME = flags.DEFINE_string( 28 | 'output_filename', None, 'Output filename for dumping the input_specs.') 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def main(argv: Sequence[str]) -> None: 34 | if len(argv) > 1: 35 | raise app.UsageError('Too many command-line arguments.') 36 | 37 | experiment_config = experiment_registry.get(_EXP.value)() 38 | 39 | specs = dump_input_specs_lib.extract_input_specs(experiment_config) 40 | out_str = dump_input_specs_lib.specs_to_string(FLAGS.exp, specs) 41 | with tf.io.gfile.GFile(_OUTPUT_FILENAME.value, 'w') as fout: 42 | fout.write(out_str) 43 | print(out_str) 44 | 45 | 46 | if __name__ == '__main__': 47 | flags.mark_flags_as_required(['exp', 'output_filename']) 48 | app.run(main) 49 | -------------------------------------------------------------------------------- /paxml/tasks/vision/input_generator_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for input_generator.""" 17 | 18 | from absl.testing import absltest 19 | from paxml.tasks.vision import input_generator 20 | 21 | 22 | class InputGeneratorTest(absltest.TestCase): 23 | 24 | def _check_data_shape(self, 25 | p, 26 | is_multlabels=False, 27 | batch_size=8, 28 | image_size=33): 29 | p.data_shape = (image_size, image_size, 3) 30 | p.batch_size = batch_size 31 | p.num_batcher_threads = 1 32 | p.file_parallelism = 1 33 | p.file_buffer_size = 1 34 | inp = p.Instantiate() 35 | batch = inp.GetPreprocessedInputBatch() 36 | b, (h, w, d) = p.batch_size, p.data_shape 37 | self.assertEqual(batch.image.shape, (b, h, w, d)) 38 | self.assertEqual(batch.label_probs.shape, (b, p.num_classes)) 39 | 40 | def testImageNetValidation(self): 41 | p = input_generator.ImageNetValidation.Params() 42 | self._check_data_shape(p) 43 | 44 | def testImageNetTrain(self): 45 | p = input_generator.ImageNetTrain.Params() 46 | self._check_data_shape(p) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /paxml/experiment_imports_all_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Test experiment configurations import and construction.""" 17 | 18 | from absl import app 19 | from absl import flags 20 | from absl.testing import absltest 21 | import jax 22 | from paxml import experiment_imports_test_helper 23 | from paxml import experiment_registry 24 | 25 | flags.DEFINE_list( 26 | 'exclude_regexes', [], 27 | 'Exclusion regexes of experiment configurations to be passed to the smoke ' 28 | 'test. The matching experiment configurations will be disabled from the ' 29 | 'smoke test.') 30 | flags.DEFINE_list( 31 | 'include_only_regexes', [], 32 | 'If provided, only experiments with names matching these regexes will be ' 33 | 'tested.') 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | class Test(experiment_imports_test_helper.ExperimentImportsTestHelper): 39 | pass 40 | 41 | 42 | def main(args): 43 | del args # Unused. 44 | 45 | n = Test.create_test_methods_for_all_registered_experiments( 46 | experiment_registry, 47 | task_regexes=[''], 48 | exclude_regexes=FLAGS.exclude_regexes, 49 | include_only_regexes=FLAGS.include_only_regexes) 50 | assert n > 0, 'No experiment registered!' 51 | 52 | absltest.main() 53 | 54 | 55 | if __name__ == '__main__': 56 | jax.config.parse_flags_with_absl() 57 | app.run(main) 58 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/convert_seqio_task_objects.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def convert_seqio_tasks(config: _T) -> _T:
  """Rewrites SeqIO `Task` instances inside `config` into name-based configs.

  SeqIO tasks live in a global registry keyed by name, and converting a Task
  object back into the config that produced it is hard; re-fetching the task
  by name is therefore a reasonable stand-in for now.

  Args:
    config: Fiddle config, or nested structure of configs.

  Returns:
    Version of config without SeqIO task instances.
  """

  def _replace_task(node, state: daglish.State):
    if not isinstance(node, seqio.Task):
      return state.map_children(node)
    # Note: the resulting config exposes a `task_or_mixture_name` parameter
    # instead of `name`, which slightly changes the config API.
    return pax_fiddle.Config(seqio.get_mixture_or_task, node.name)

  return daglish.MemoizedTraversal.run(_replace_task, config)
class WrapNestedMapsTest(fiddle.testing.TestCase):
  """Tests for wrap_nested_maps.wrap_nested_maps."""

  def test_wrap_nested_maps(self):
    # `shared` appears twice in the input (under "subdict" and inside a
    # tuple) to verify that the rewrite preserves object aliasing.
    shared = {"value": 1}
    config = [
        pytypes.NestedMap(
            foo={
                "subdict": pytypes.NestedMap(bar=shared),
                "another_value": (shared,),
            }
        )
    ]
    result = wrap_nested_maps.wrap_nested_maps(config=config)
    # The expected DAG mirrors the sharing via `shared2`; assertDagEqual
    # checks structural equality *and* the aliasing pattern, so both
    # occurrences of `shared2` must be the same object.
    shared2 = {"value": 1}
    expected = [
        pax_fiddle.Config(
            nested_map_config_helper.make_nested_map,
            base={
                "foo": {
                    "subdict": pax_fiddle.Config(
                        nested_map_config_helper.make_nested_map,
                        base={"bar": shared2},
                    ),
                    "another_value": (shared2,),
                }
            },
        )
    ]
    self.assertDagEqual(result, expected)
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Experimental packages. 17 | 18 | load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test") 19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 20 | 21 | package(default_visibility = JAX_VISIBILITY) 22 | 23 | licenses(["notice"]) 24 | 25 | pytype_strict_library( 26 | name = "baseline_experiment", 27 | srcs = ["baseline_experiment.py"], 28 | deps = [ 29 | # Implicit fiddle dependency. 30 | "//paxml:parameterized_experiment", 31 | "//praxis:pax_fiddle", 32 | ], 33 | ) 34 | 35 | pytype_strict_library( 36 | name = "nested_map_config_helper", 37 | srcs = ["nested_map_config_helper.py"], 38 | deps = ["//praxis:pytypes"], 39 | ) 40 | 41 | pytype_strict_test( 42 | name = "nested_map_config_helper_test", 43 | srcs = ["nested_map_config_helper_test.py"], 44 | deps = [ 45 | ":nested_map_config_helper", 46 | # Implicit absl.testing.absltest.absltest dependency. 47 | ], 48 | ) 49 | 50 | pytype_strict_test( 51 | name = "baseline_experiment_test", 52 | srcs = ["baseline_experiment_test.py"], 53 | deps = [ 54 | ":baseline_experiment", 55 | # Implicit absl.testing.absltest.absltest dependency. 
56 | "//paxml:parameterized_experiment", 57 | "//paxml:tasks_lib", 58 | "//praxis:base_model", 59 | "//praxis:pax_fiddle", 60 | ], 61 | ) 62 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Workspace file for PAXML.""" 17 | 18 | load("//paxml:build_defs.bzl", "pax_targets") # @unused 19 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 20 | 21 | http_archive( 22 | name = "bazel_skylib", 23 | urls = [ 24 | "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", 25 | "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", 26 | ], 27 | sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", 28 | ) 29 | 30 | load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") #buildifier: disable=load-on-top 31 | 32 | bazel_skylib_workspace() 33 | 34 | http_archive( 35 | name = "rules_python", 36 | sha256 = "cdf6b84084aad8f10bf20b46b77cb48d83c319ebe6458a18e9d2cebf57807cdd", 37 | strip_prefix = "rules_python-0.8.1", 38 | url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.8.1.tar.gz", 39 | ) 40 | 41 | http_archive( 42 | name = "zlib", 43 | 
build_file = "@com_google_protobuf//:third_party/zlib.BUILD", 44 | sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", 45 | strip_prefix = "zlib-1.2.11", 46 | urls = [ 47 | "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", 48 | "https://zlib.net/zlib-1.2.11.tar.gz", 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | #! 
# Trains the Pile126M config on a single node.
#
# Usage: run_pile_singlenode.sh <tfds-data-dir> <vocab-path> [prec] [num-gpus]
#        [percore-batch-size] [log-dir]
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH

BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=51200
--xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


# Quote expansions so paths with spaces do not word-split.
mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Pile126M \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS}, 1, 1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.PERCORE_BATCH_SIZE="${PERCORE_BATCH_SIZE}" \
    --tfds_data_dir="${TFDS_DATA_DIR}" \
    --alsologtostderr \
    2>&1 | tee "${LOG_DIR}/pile_output.log"

# With `set -o pipefail`, $? reflects a failure anywhere in the pipeline,
# not just the trailing `tee`.
EXP_STATUS=$?

# Numeric comparison (-ne) instead of string !=; quote to survive empty values.
if [ "$EXP_STATUS" -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi

echo "Output written to ${LOG_DIR}/pile_output.log"
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for partitioning."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from paxml import partitioning
from praxis import py_utils
from praxis import test_utils

NestedMap = py_utils.NestedMap
PartitionSpec = jax.sharding.PartitionSpec


class PartitioningTest(test_utils.TestCase):
  """Tests for partitioning.filter_nestedmap."""

  @parameterized.parameters([(NestedMap, dict), (dict, NestedMap)])
  def test_filter_nested_map_basics(self, src_type, filter_type):
    # The filter's container type should not matter: the result keeps the
    # source's container type and only the filtered keys.
    full = src_type(a=1, b=src_type(c=2, d=[3, src_type(e=6, f=7)]))
    keep = filter_type(a=0, b=filter_type(d=[0, filter_type(e=0)]))

    filtered = partitioning.filter_nestedmap(full, keep)

    self.assertIsInstance(filtered, src_type)
    self.assertEqual(src_type(a=1, b=src_type(d=[3, src_type(e=6)])), filtered)

  def test_filter_nested_map_with_partition_spec(self):
    # PartitionSpec leaves must pass through the filter untouched.
    full = dict(a=[PartitionSpec(None), dict(b=2, c=PartitionSpec(None))])
    keep = dict(a=[0, dict(c=0)])

    filtered = partitioning.filter_nestedmap(full, keep)

    self.assertIsInstance(filtered, dict)
    self.assertEqual(
        dict(a=[PartitionSpec(None), dict(c=PartitionSpec(None))]), filtered
    )


if __name__ == '__main__':
  absltest.main()
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash
# Runs LAMBADA eval for the 126M config on a single node, restoring from a
# pretrained checkpoint directory.
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
### Path to pretrained log_dir (checkpoints are restored from here).
LOG_DIR=$6

export VOCAB_PATH=$VOCAB_PATH
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=51200
--xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


# Quote expansions so paths with spaces do not word-split.
mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Lambada126M \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS}, 1, 1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.PERCORE_BATCH_SIZE="${PERCORE_BATCH_SIZE}" \
    --tfds_data_dir="${TFDS_DATA_DIR}" \
    --mode='eval' \
    --alsologtostderr \
    2>&1 | tee "${LOG_DIR}/lambada_output.log"

# With `set -o pipefail`, $? reflects a failure anywhere in the pipeline,
# not just the trailing `tee`.
EXP_STATUS=$?

# Numeric comparison (-ne) instead of string !=; quote to survive empty values.
if [ "$EXP_STATUS" -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi

echo "Output written to ${LOG_DIR}/lambada_output.log"
# Launches a base config on a SLURM cluster (one task per node). Edit the
# flags under --multiprocess_gpu below to suit your setup.
set -u

CONFIG=$1
PREC=${2:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${3:-8}             # Number of GPUs (1, 2, 4, 8)
LOG_DIR=${4:-"test_logdir"}
TFDS_DATA_DIR=${5:-'None'}
VOCAB_PATH=${6:-'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'}
ADDITIONAL_ARGS=${7:-""}

export VOCAB_PATH=$VOCAB_PATH

export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=51200 \
--xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


# Quote expansions so paths with spaces do not word-split.
mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.${CONFIG} \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --multiprocess_gpu \
    --server_addr="${SLURM_LAUNCH_NODE_IPADDR}:12345" \
    --num_hosts="$SLURM_NTASKS" \
    --host_idx="$SLURM_PROCID" \
    --alsologtostderr \
    $([[ $TFDS_DATA_DIR != "None" ]] && echo --tfds_data_dir="$TFDS_DATA_DIR") \
    ${ADDITIONAL_ARGS}
    # ADDITIONAL_ARGS is intentionally unquoted so callers can pass several
    # flags in one argument.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash
# Assumes a SLURM cluster; edit the flags under --multiprocess_gpu below to
# suit your setup.
set -u

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}             # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH
export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.85}
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=51200
--xla_gpu_enable_command_buffer=''"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"


## NOTE: 126M trained with pure data parallel
# Quote expansions so paths with spaces do not word-split.
mkdir -p "${LOG_DIR}"
python3 -u -m paxml.main \
    --job_log_dir="${LOG_DIR}" \
    --fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Pile126M \
    --tfds_data_dir="${TFDS_DATA_DIR}" \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[${NUM_GPUS},1,1]" \
    --fdl.DCN_MESH_SHAPE="[${SLURM_JOB_NUM_NODES},1,1]" \
    --fdl.PERCORE_BATCH_SIZE="$PERCORE_BATCH_SIZE" \
    --multiprocess_gpu \
    --server_addr="${SLURM_LAUNCH_NODE_IPADDR}:12345" \
    --num_hosts="$SLURM_NTASKS" \
    --host_idx="$SLURM_PROCID" \
    --alsologtostderr
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Program executor that drives the training/evaluation loops.""" 17 | 18 | import abc 19 | from typing import Any, Sequence 20 | 21 | from etils import epath 22 | from paxml import decode_programs as decode_program_lib 23 | from paxml import partitioning 24 | from paxml import programs 25 | from paxml import tasks_lib 26 | from paxml import trainer_lib 27 | from praxis import base_input 28 | from praxis import pax_fiddle 29 | 30 | 31 | class BaseExecutor(metaclass=abc.ABCMeta): 32 | 33 | @abc.abstractmethod 34 | def setup( 35 | self, 36 | jax_task: tasks_lib.SingleTask, 37 | job_log_dir: epath.Path, 38 | checkpointer: Any, 39 | partitioner: partitioning.Partitioner, 40 | input_specs_provider: base_input.BaseInputSpecsProvider, 41 | # TODO(laigd): encapsulate train_input_p in train_program. 42 | train_input_p: pax_fiddle.Config[base_input.BaseInput], 43 | train_program: programs.BaseTrainProgram, 44 | eval_programs: Sequence[programs.BaseEvalProgram], 45 | decode_programs: Sequence[decode_program_lib.SingleTaskDecodeProgram], 46 | # TODO(laigd): this shouldn't be part of the executor API, consider adding 47 | # a dedicated executor for auto-tuning and get rid of this instead. 
# From paxml/experiment_vars_summary_parser.py: a lightweight parser kept free
# of heavy core-paxml dependencies.
def parse(cls_vars_summary: str) -> dict[str, Any]:
  """Parses a class variables summary into a dict of variable name -> value.

  Understands the format produced by experiment_utils.get_cls_vars_summary().
  Values that cannot be evaluated with ast.literal_eval (e.g. reprs of class
  instances) are kept as raw strings.

  Args:
    cls_vars_summary: A summary in the format created by
      experiment_utils.get_cls_vars_summary().

  Returns:
    Dictionary of variable names and values.
  """
  parsed: dict[str, Any] = {}
  for raw_line in cls_vars_summary.splitlines():
    stripped = raw_line.strip()
    # Class-name headers end with ':'; blank lines separate classes.
    if not stripped or stripped.endswith(':'):
      continue
    name, sep, value = raw_line.partition(':')
    if sep:
      name, value = name.strip(), value.strip()
      try:
        parsed[name] = ast.literal_eval(value)
      except (ValueError, SyntaxError):
        # Not a Python literal (e.g. an object repr): keep the raw string.
        parsed[name] = value
    else:
      logging.warning('Warning: Ignoring line with unexpected format: %s',
                      raw_line)
  return parsed
24 | # 25 | # 1.1 26 | # - Metadata has a new key 'train_state_metadata', which is a pytree of array 27 | # metadata corresponding to the train state, including shape, dtype and 28 | # is_masked_node for `TrainState.mdl_vars`. 29 | # 30 | # 1.0 31 | # - Checkpoints folders are organized into per-step directories, where each has 32 | # a subdirectory for every item. 33 | # - The items are 'state' and 'metadata'. 34 | # - Per-step metadata contains a version key. 35 | # 36 | # 0.0 37 | # - Checkpoints do not have per-item directories. 38 | # - Flax checkpoints may or may not be contained within a step directory. In 39 | # other words, the msgpack file may be 'checkpoint_1' instead of 40 | # 'checkpoint_1/checkpoint', where 'checkpoint' is the msgpack file. 41 | 42 | # TODO(b/273803615) When rolled out globally, make _OCDBT_VERSION the standard 43 | # version. 44 | _OCDBT_VERSION: float = 1.2 45 | _VERSION: float = 1.1 46 | _VERSION_KEY: str = 'version' 47 | 48 | 49 | def get_version(tensorstore_use_ocdbt: bool | None = None) -> float: 50 | if tensorstore_use_ocdbt is None: 51 | raise ValueError('Must set the value of `tensorstore_use_ocdbt`.') 52 | if tensorstore_use_ocdbt: 53 | return _OCDBT_VERSION 54 | return _VERSION 55 | 56 | 57 | def get_version_key() -> str: 58 | return _VERSION_KEY 59 | -------------------------------------------------------------------------------- /paxml/pip_package/build_pip_pkg.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env bash
# Builds the paxml pip wheel from bazel outputs into a destination directory.

set -e

PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"

PIP_FILE_PREFIX="pip_package/"

export PYTHON_VERSION="${PYTHON_VERSION:-3}"
export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION}"

if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then
  PYTHON="python${PYTHON_VERSION}"
else
  PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}"
fi

function main() {
  DEST=${1}
  if [[ -z "${DEST}" ]]; then
    echo "No destination directory provided."
    exit 1
  fi

  # Create the directory, then do readlink on it to get an absolute path with
  # tilde characters resolved to the destination directory.
  # Quoted: an unquoted $DEST would word-split or glob on unusual paths.
  [ ! -d "${DEST}" ] && mkdir -p "${DEST}"

  DEST=$(readlink -f "${DEST}")
  echo "=== destination directory: ${DEST}"

  TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)

  echo "$(date) : === Using tmpdir: ${TMPDIR}"

  echo "=== Copy paxml files"

  cp setup.py "${TMPDIR}"
  cp requirements.in "${TMPDIR}"
  cp LICENSE "${TMPDIR}"
  rsync -avm -L paxml "${TMPDIR}"
  rsync -avm -L --include="*.so" --include="*_pb2.py" \
    --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \
    bazel-bin/paxml "${TMPDIR}"

  # Quote TMPDIR everywhere: mktemp paths can contain user-controlled prefixes.
  pushd "${TMPDIR}"
  echo "$(date) : === Building wheel"

  ${PYTHON} setup.py bdist_wheel
  cp dist/*.whl "${DEST}"
  popd
  rm -rf "${TMPDIR}"
  echo "$(date) : === Output wheel file is in: ${DEST}"
}

main "$@"
from paxml import base_experiment
from paxml import experiment_registry
from paxml import main
from absl.testing import absltest


class FakeExperimentClassForTest(base_experiment.BaseExperiment):
  # Minimal experiment used only to exercise registry lookup.
  pass


class MainTest(absltest.TestCase):
  """Tests for main.get_experiment lookup behavior."""

  def test_get_experiment_failure_to_import_module(self):
    # fake_module_for_paxml_main_test is not a module that exists in this test.
    expected_error = (
        'Could not find experiment'
        ' `fake_module_for_paxml_main_test.Experiment9876` because could not'
        ' import module `fake_module_for_paxml_main_test`'
    )
    with self.assertRaisesRegex(ValueError, expected_error):
      main.get_experiment('fake_module_for_paxml_main_test.Experiment9876')

  def test_get_experiment_failure_to_find_experiment_in_module(self):
    # Module builtins is guaranteed to exist, but there's no corresponding
    # experiment in the builtins module.
    with self.assertRaisesRegex(
        ValueError,
        'Could not find experiment `builtins.Experiment9876`.\n'
        'Registered experiments are: {}',
    ):
      main.get_experiment('builtins.Experiment9876')

  def test_get_experiment_success(self):
    try:
      experiment_registry.register(FakeExperimentClassForTest)

      name = '{}.{}'.format(
          FakeExperimentClassForTest.__module__,
          FakeExperimentClassForTest.__qualname__,
      )
      self.assertEqual(main.get_experiment(name), FakeExperimentClassForTest)
    finally:
      # Reset registry to empty so other tests see a clean slate.
      experiment_registry._ExperimentRegistryHelper._registry = {}


if __name__ == '__main__':
  absltest.main()
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup.py file for paxml."""

import os
from setuptools import find_namespace_packages
from setuptools import setup

# Set this envvar to avoid installing packages from head that can overwrite
# existing installs of those packages, e.g., jax
SKIP_HEAD_INSTALLS = os.environ.get('SKIP_HEAD_INSTALLS', '')


def _get_requirements():
  """Parses requirements.in and returns the list of install requirements."""
  requirements = []
  path = os.path.join(os.path.dirname(__file__), './requirements.in')
  with open(path, 'r') as f:
    for raw_line in f:
      entry = raw_line.strip()
      if not entry:
        continue  # Blank line.
      if entry[0] == '#':
        continue  # Comment.
      if ' @ ' in entry and SKIP_HEAD_INSTALLS:
        continue  # Install-from-head pin, skipped when SKIP_HEAD_INSTALLS set.
      requirements.append(entry)
  return requirements


install_requires = _get_requirements()

setup(
    name='paxml',
    version='1.4.0',  # use major/minor version number, e.g. "0.1.0"
    description=(
        'Framework to configure and run machine learning experiments '
        'on top of Jax.'
    ),
    author='PAX team',
    author_email='pax-dev@google.com',
    packages=find_namespace_packages(include=['paxml*']),
    python_requires='>=3.10',
    install_requires=install_requires,
    url='https://github.com/google/paxml',
    license='Apache-2.0',
    extras_require={
        'gpu': ['jsonlines==3.1.0', 'pysimdjson==5.0.2', 'zstandard==0.18.0'],
    },
    classifiers=[
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
    ],
    zip_safe=False,
)
"""Tests for codegen_highlevel_parameterization."""

from absl.testing import absltest
from fiddle._src.codegen.auto_config import ir_printer
from paxml.tools.fiddle import codegen
from paxml.tools.fiddle import codegen_highlevel_parameterization
from paxml.tools.fiddle import codegen_tracer


class CodegenTest(absltest.TestCase):
  """Tests the HighlevelParameterization codegen pass."""

  def _run_pass(self, value):
    """Builds a codegen task from `value` and applies the pass in place."""
    task = codegen.InitTask()(value)
    codegen_pass = codegen_highlevel_parameterization.HighlevelParameterization(
        lowercasing=False
    )
    # The pass mutates and returns the very same task object.
    self.assertIs(codegen_pass(task), task)
    return task

  def test_highlevel_parameterization_transform(self):
    tracer_obj = codegen_tracer.make_tracer("tracer_name", 1)
    task = self._run_pass(tracer_obj)
    self.assertEqual(
        ir_printer.format_expr(task.top_level_call.fn.output_value),
        "self.tracer_name",
    )

  def test_highlevel_parameterization_transforms_keys(self):
    key_foo = codegen_tracer.make_tracer("tracer_foo", 1)
    key_bar = codegen_tracer.make_tracer("tracer_bar", 2)
    task = self._run_pass({key_foo: [1, 2, 3], 0: 10, key_bar: [4, 5, 6]})
    converted = task.top_level_call.fn.output_value

    with self.subTest("order_preservation"):
      self.assertEqual(list(converted.values()), [[1, 2, 3], 10, [4, 5, 6]])

    with self.subTest("tracer_conversion"):
      keys = list(converted.keys())
      self.assertEqual(keys[0].attribute, "tracer_foo")
      self.assertEqual(keys[2].attribute, "tracer_bar")


if __name__ == "__main__":
  absltest.main()
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for creating and parsing experiment class variables summaries."""

import unittest

from paxml import base_experiment
from paxml import experiment_utils
from paxml import experiment_vars_summary_parser


class TestExperimentA(base_experiment.BaseExperiment):
  """Fixture experiment declaring an int, a str and a tuple class variable."""

  INT_VAR = 0
  STR_VAR = 'A'
  TUPLE_VAR = (0, 'A')


class TestExperimentB(TestExperimentA):
  """Fixture subclass that overrides STR_VAR and introduces BOOL_VAR."""

  STR_VAR = 'B'
  BOOL_VAR = True


class ExperimentVarsSummaryTest(unittest.TestCase):

  def test_create_and_parse_cls_vars_summary(self):
    """Round-trips a class-vars summary through the parser.

    The summary groups variables under the class in the MRO that (last)
    defines them — BaseExperiment first, then TestExperimentA, then
    TestExperimentB — with blank lines between sections.
    """
    summary = experiment_utils.get_cls_vars_summary(TestExperimentB)

    summary_lines = summary.splitlines()
    # Section 1: attribute contributed by the abstract base class machinery.
    self.assertEqual(summary_lines[0], 'paxml.base_experiment.BaseExperiment:')
    self.assertRegex(
        summary_lines[1], '  _abc_impl: <_abc._abc_data object at .*>'
    )
    self.assertEqual(summary_lines[2], '')
    # Section 2: variables still owned by TestExperimentA. STR_VAR is absent
    # here because TestExperimentB overrides it below.
    self.assertEqual(summary_lines[3], '__main__.TestExperimentA:')
    self.assertEqual(summary_lines[4], '  INT_VAR: 0')
    self.assertEqual(summary_lines[5], "  TUPLE_VAR: (0, 'A')")
    self.assertEqual(summary_lines[6], '')
    # Section 3: variables overridden or introduced by TestExperimentB.
    self.assertEqual(summary_lines[7], '__main__.TestExperimentB:')
    self.assertEqual(summary_lines[8], '  STR_VAR: B')
    self.assertEqual(summary_lines[9], '  BOOL_VAR: True')

    # Parsing the rendered summary must recover every variable, with the
    # most-derived values winning for overridden names (STR_VAR == 'B').
    cls_vars = experiment_vars_summary_parser.parse(summary)
    self.assertCountEqual(
        cls_vars.keys(),
        ['_abc_impl', 'INT_VAR', 'STR_VAR', 'TUPLE_VAR', 'BOOL_VAR'],
    )
    self.assertEqual(cls_vars['INT_VAR'], 0)
    self.assertEqual(cls_vars['TUPLE_VAR'], (0, 'A'))
    self.assertEqual(cls_vars['STR_VAR'], 'B')
    self.assertEqual(cls_vars['BOOL_VAR'], True)


if __name__ == '__main__':
  unittest.main()
-------------------------------------------------------------------------------- /paxml/tools/fiddle/graphviz_utils.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides a utility function for rendering configs with normalization."""

from fiddle import graphviz as fiddle_graphviz
import graphviz
from paxml.tools.fiddle import config_normalization


def render(
    config,
    *,
    max_depth: int | None = 4,
    max_str_length: int | None = 100,
    remove_defaults: bool = True,
    convert_dataclasses: bool = True,
    remove_sharding_annotations: bool = False,
    unshare_sharding_config: bool = True,
) -> graphviz.Graph:
  """Renders a config with normalization.

  Args:
    config: The config to render.
    max_depth: The maximum depth of the rendered graph.
    max_str_length: The maximum length of the rendered strings.
    remove_defaults: Whether to remove default values. Often with Pax configs,
      dataclass field defaulting magic means that you get large, expanded
      templates that may actually be unused or equal to their default values.
    convert_dataclasses: Whether to convert dataclass instances to configs. This
      will only be applied if the dataclasses do not have __post_init__
      functions, as __post_init__ can obscure the initial call values.
    remove_sharding_annotations: Whether to remove sharding annotations.
    unshare_sharding_config: If remove_sharding_annotations=False, whether to
      unshare values in sharding configuration. If
      remove_sharding_annotations=True, this should be False.

  Returns:
    A rendered graph.
  """
  # Apply the requested normalizations first so the emitted graph is smaller
  # and easier to read; rendering itself is delegated to fiddle's renderer.
  normalizer = config_normalization.ConfigNormalizer(
      remove_defaults=remove_defaults,
      convert_dataclasses=convert_dataclasses,
      remove_sharding_annotations=remove_sharding_annotations,
      unshare_sharding_config=unshare_sharding_config,
  )
  config = normalizer(config)
  return fiddle_graphviz.render(
      config=config, max_depth=max_depth, max_str_length=max_str_length
  )
-------------------------------------------------------------------------------- /paxml/tools/fiddle/unshare_sharding.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Copies sharding annotations, making them unshared.

Sharding doesn't matter for sharding annotations, but Fiddle maintains this
information in general for shared objects.
"""

from typing import Any, TypeVar

from fiddle import daglish
from paxml.tools.fiddle import remove_sharding


_T = TypeVar("_T")


def _deep_copy(value):
  """Returns a copy of a value, unsharing all sub-values.

  Unlike regular copy.deepcopy, this method removes all sharing of sub-objects
  in `value`. See deepcopy documentation. Given an input

    input = [shared, shared]

  where `shared` is any immutable variable,

    `copy.deepcopy(input)` returns `[shared_2, shared_2]`
    `_deep_copy(input)` returns `[unshared_3, unshared_4]`

  where usually shared == shared_2 == unshared_3 == unshared_4 (except custom
  equivalence operators). As indicated, the deepcopy input has the same object
  in both list positions, whereas this _deep_copy() produces different objects.

  Args:
    value: Value to copy.
  """
  # BasicTraversal (unlike MemoizedTraversal) rebuilds every node without a
  # memo table, so repeated sub-objects come back as distinct copies.
  return daglish.BasicTraversal.run(
      lambda x, state: state.map_children(x), value
  )


class CustomMemoizedTraversal(daglish.MemoizedTraversal):
  """Memoized traversal that deliberately forgets sharding annotations.

  Deleting the memo entry for a sharding annotation right after it is
  processed means later occurrences of the same annotation object are
  re-processed (and thus re-copied) instead of being shared.
  """

  def apply(self, value: Any, state: Any) -> Any:
    result = super().apply(value, state)
    if remove_sharding._is_sharding_annotation(value, state.current_path):  # pylint: disable=protected-access
      # Drop the memoized result so this annotation is never shared.
      del self.memo[id(value)]
    return result


def unshare_sharding(config: _T) -> _T:
  """Unshares sharding annotations.

  Args:
    config: Base configuration or structure of configuration.

  Returns:
    Config sharding annotations unshared.
73 | """ 74 | 75 | def transform(value, state: daglish.State): 76 | if remove_sharding._is_sharding_annotation(value, state.current_path): # pylint: disable=protected-access 77 | return _deep_copy(value) 78 | return state.map_children(value) 79 | 80 | return CustomMemoizedTraversal.run(transform, config) 81 | -------------------------------------------------------------------------------- /paxml/docs/tasks.md: -------------------------------------------------------------------------------- 1 | # Tasks 2 | 3 | doc/pax/tasks 4 | 5 | [TOC] 6 | 7 | ## Introduction 8 | 9 | In general, an ML **task** identifies what you are trying to accomplish with 10 | your ML model. Some examples of ML tasks include: 11 | 12 |
13 | 14 |
15 | 16 | * [Regression][regression] 17 | * [Classification][classification] 18 | 19 |
20 | 21 |
22 | 23 | * [Clustering][clustering] 24 | * [Anomaly detection][anomaly] 25 | 26 |
27 | 28 |
29 | 30 | * [Transcription][transcription] 31 | * [Machine translation][translation] 32 | 33 |
34 | 35 |

In Pax, a **Task** is an object (derived from the [BaseTask][base-task] class)
that is composed of the necessary components to address a given ML task. These
components include:

*   A Model
*   Metrics (optional)
*   An optimizer

A **Mixture** is a term for *a collection of Tasks* and enables fine-tuning a
model on multiple Tasks simultaneously.

> Tip: For a rudimentary introduction to the basic Pax components, check out
> [Pax Elements][pax-elements]. If you want to dive in for a hands-on
> experience, try the [Pax Model and Task Jupyter Notebook][model_ipynb].

## Task How-To's

### Define a Task

A Task contains one or more Models and Learner/Optimizers. The simplest Task
subclass is a `SingleTask` which requires the following HParams:

```python
  class HParams(base_task.BaseTask.HParams):
    """Task parameters.

    Attributes:
      name: Name of this task object, must be a valid identifier.
      model: The underlying JAX model encapsulating all the layers.
      train: HParams to control how this task should be trained.
      metrics: A BaseMetrics aggregator class to determine how metrics are
        computed.
      loss_aggregator: A LossAggregator aggregator class to determine how the
        losses are aggregated (e.g. single or MultiLoss)
      vn: HParams to control variational noise.
72 | ``` 73 | 74 | 75 | 76 | 77 | 78 | [anomaly]: internal-link/ml-glossary/#anomaly-detection 79 | [base-task]: https://github.com/google/paxml/tree/main/paxml/base_task.py 80 | [classification]: internal-link/ml-glossary/#classification_model 81 | [clustering]: internal-link/ml-glossary/#clustering 82 | [model_ipynb]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md#pax-model-and-task 83 | [pax-elements]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md#pax-elements 84 | [regression]: internal-link/ml-glossary/#regression-model 85 | [transcription]: https://fireflies.ai/blog/what-is-ai-transcription/ 86 | [translation]: https://en.wikipedia.org/wiki/Machine_translation 87 | -------------------------------------------------------------------------------- /paxml/tasks/vision/params/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Vision modeling model configurations. 18 | 19 | load("//paxml:paxml.bzl", "pytype_library") 20 | # Import for Google-internal testing code. 
load("//paxml:build_defs.bzl", "pax_targets")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

licenses(["notice"])

# Experiment configurations for ImageNet ResNet models.
pytype_library(
    name = "params",
    srcs = [
        "imagenet_resnets.py",
    ],
    deps = [
        # Implicit absl.flags dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:learners",
        "//paxml:tasks_lib",
        "//paxml/tasks/vision:input_generator",
        "//praxis:base_input",
        "//praxis:base_layer",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        "//praxis:schedules",
        "//praxis/layers",
    ],
)

# Self-contained MNIST experiment (uses tensorflow_datasets for input).
pytype_library(
    name = "mnist",
    srcs = ["mnist.py"],
    deps = [
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:base_task",
        "//paxml:experiment_registry",
        "//paxml:learners",
        "//paxml:tasks_lib",
        "//praxis:base_input",
        "//praxis:base_layer",
        "//praxis:base_model",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        "//praxis:schedules",
        "//praxis/layers:activations",
        "//praxis/layers:convolutions",
        "//praxis/layers:linears",
        "//praxis/layers:poolings",
        # Implicit tensorflow_no_contrib dependency.
        # Implicit tensorflow_datasets dependency.
    ],
)

# Generated runnable targets for the experiments above.
pax_targets(
    experiments = [
        ":mnist",
    ],
    prefix_name = "mnist",
)

pax_targets(
    experiments = [
        ":params",
    ],
)

# Google-internal testing code.
-------------------------------------------------------------------------------- /paxml/tasks/lm/BUILD: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Description:
#   Language modeling-specific libraries and model configurations

load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

licenses(["notice"])

# Input pipeline for LM training/eval data.
pytype_strict_library(
    name = "input_generator",
    srcs = ["input_generator.py"],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit jax dependency.
        "//praxis:base_input",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        # Implicit tensorflow_no_contrib dependency.
    ],
)

# Shared model/experiment parameter definitions for LM experiments.
pytype_strict_library(
    name = "model_params",
    srcs = ["model_params.py"],
    deps = [
        # Implicit fiddle dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//paxml:tasks_lib",
        "//praxis:asserts",
        "//praxis:base_layer",
        "//praxis:base_model",
        "//praxis:optimizers",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:schedules",
        "//praxis/layers",
        "//praxis/layers:activations",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:gpu_fast_attention",
        "//praxis/layers:models",
        "//praxis/layers:transformer_models",
        "//praxis/layers/injection:fp8_nvidia_gpu",
    ],
)

# Test fixtures consumed by input_generator_test.
filegroup(
    name = "testdata",
    testonly = 1,
    srcs = glob(["testdata/*"]),
)

pytype_strict_test(
    name = "input_generator_test",
    srcs = ["input_generator_test.py"],
    data = [":testdata"],
    deps = [
        ":input_generator",
        # Implicit absl.testing.absltest.absltest dependency.
        # Implicit absl.testing.parameterized dependency.
        # Implicit numpy dependency.
        "//paxml:test_helper",
        "//praxis:base_hyperparams",
        "//praxis:pax_fiddle",
        "//praxis:test_utils",
        # Implicit tensorflow_no_contrib dependency.
    ],
)
-------------------------------------------------------------------------------- /paxml/tools/BUILD: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tools and utilities for Pax users.

load("//paxml:paxml.bzl", "pytype_strict_library")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

licenses(["notice"])

# Library backing the dump_input_specs tool.
pytype_strict_library(
    name = "dump_input_specs_lib",
    srcs = ["dump_input_specs_lib.py"],
    deps = [
        # Implicit absl.logging dependency.
        # Implicit jax dependency.
        "//paxml:base_experiment",
        "//praxis:base_hyperparams",
        "//praxis:base_input",
        "//praxis:pytypes",
        # Implicit pyglove dependency.
    ],
)

# Tool entry-point scripts exposed to downstream packages.
exports_files([
    "dump_hparams.py",
    "dump_input_specs.py",
    "model_analysis.py",
    "validate_config.py",
])

# Library backing the validate_config tool.
pytype_strict_library(
    name = "validate_config_lib",
    srcs = ["validate_config_lib.py"],
    deps = [
        # Implicit absl.app dependency.
        # Implicit absl.flags dependency.
        # Implicit absl.logging dependency.
        # Implicit fiddle.absl_flags dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
    ],
)

# Library backing the dump_hparams tool.
pytype_strict_library(
    name = "dump_hparams_lib",
    srcs = ["dump_hparams_lib.py"],
    deps = [
        # Implicit absl.app dependency.
        # Implicit absl.flags dependency.
        # Implicit absl.logging dependency.
        # Implicit etils dependency.
        # Implicit fiddle.absl_flags dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        "//paxml:base_experiment",
        "//paxml:experiment_registry",
        "//paxml:experiment_utils",
        "//praxis:base_hyperparams",
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis:pytypes",
        # Implicit pyglove dependency.
        # Implicit tensorflow_no_contrib dependency.
    ],
)
-------------------------------------------------------------------------------- /paxml/pip_package/release.Dockerfile: --------------------------------------------------------------------------------
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
FROM $base_image

LABEL maintainer="Pax team <pax-dev@google.com>"

# Re-declare args because the args declared before FROM can't be used in any
# instruction after a FROM.
ARG cpu_base_image="ubuntu:22.04"
ARG base_image=$cpu_base_image
ARG praxis_version
ARG wheel_folder
ENV WHEEL_FOLDER $wheel_folder
ENV PYTHON_VERSION="3"
ENV PYTHON_MINOR_VERSION="10"

# Pick up some TF dependencies
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common
RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \
        build-essential \
        curl \
        git \
        pkg-config \
        rename \
        rsync \
        unzip \
        vim \
        && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install python 3.10
RUN apt-get update && apt-get install -y \
    python3 python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/apt/lists/* && \
    python3.10 -m pip install pip --upgrade && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0

# Make python3.10 the default python version
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 0

ARG bazel_version=5.1.1
# This is to install bazel, for development purposes.
ENV BAZEL_VERSION ${bazel_version}
# Download and run the official Bazel installer; the custom User-Agent avoids
# GitHub download endpoints rejecting the default curl agent.
RUN mkdir /bazel && \
    cd /bazel && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
    chmod +x bazel-*.sh && \
    ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
    cd / && \
    rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh

COPY . /paxml
RUN mkdir $WHEEL_FOLDER
# Fetch praxis from head for "release-test" builds, otherwise from the
# matching r<version> release branch; strip git-pinned requirements so pip
# resolves released packages instead.
RUN if [ "$praxis_version" = "release-test" ] ; then git clone https://github.com/google/praxis.git; else git clone -b r${praxis_version} https://github.com/google/praxis.git; fi
RUN if [ "$praxis_version" = "release-test" ] ; then sed -i 's/ @ git.*//g' /praxis/requirements.in; fi
RUN pip3 install -e praxis

RUN cp -r praxis/praxis /paxml/
RUN sed -i 's/ @ git.*//g' paxml/requirements.in
RUN pip3 install -r paxml/requirements.in

# Hoist pip_package scripts next to the sources, then build the wheel.
RUN mv paxml/paxml/pip_package /paxml/
RUN cd /paxml && bash pip_package/build.sh

WORKDIR /

CMD ["/bin/bash"]
-------------------------------------------------------------------------------- /paxml/host_callback_test.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for callback."""

from absl.testing import absltest
from paxml import host_callback


class RepositoryTest(absltest.TestCase):
  """Tests for host_callback.Repository and the namespaced accessor."""

  def test_namespace(self):
    # Values added via one namespace handle are retrievable via another
    # handle for the same namespace.
    regex_id = host_callback.repository('namespace_1').add('regex')
    self.assertEqual(
        host_callback.repository('namespace_1').get(regex_id), 'regex'
    )

  def test_namespace_same_object(self):
    repository1 = host_callback.repository('namespace_1')
    repository2 = host_callback.repository('namespace_1')
    self.assertIs(repository1, repository2)

  def test_namespace_different_object(self):
    repository1 = host_callback.repository('namespace_2')
    repository2 = host_callback.repository('namespace_3')
    self.assertIsNot(repository1, repository2)

  def test_add_and_pop(self):
    repository = host_callback.Repository()
    self.assertEqual(repository.size, 0)
    regex_id = repository.add('regex')
    self.assertEqual(repository.size, 1)
    self.assertTrue(repository.pop(regex_id))
    self.assertEqual(repository.size, 0)

  def test_pop_unknown_id(self):
    # Popping an id that was never added reports failure without side effects.
    repository = host_callback.Repository()
    self.assertEqual(repository.size, 0)
    self.assertFalse(repository.pop(0))
    self.assertEqual(repository.size, 0)

  def test_add_and_get(self):
    repository = host_callback.Repository()
    regex_id = repository.add('regex')
    self.assertEqual(repository.get(regex_id), 'regex')

  def test_get_unknown_id(self):
    repository = host_callback.Repository()
    with self.assertRaises(KeyError):
      repository.get(0)

  def test_eviction(self):
    # With max_size=1 each add displaces the previous entry.
    repository = host_callback.Repository(max_size=1)
    self.assertEqual(repository.size, 0)
    regex_ids = []
    for i in range(4):
      regex_ids.append(repository.add(f'regex_{i}'))
    self.assertEqual(repository.size, 1)

    # The last regex still remains.
    self.assertEqual(repository.get(regex_ids[-1]), 'regex_3')

    # The others were evicted.
    for regex_id in regex_ids[:-1]:
      with self.assertRaises(KeyError):
        repository.get(regex_id)


if __name__ == '__main__':
  absltest.main()
-------------------------------------------------------------------------------- /paxml/setup_jax.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Utilities to set up JAX global configs."""

import dataclasses

from absl import logging
import jax
from praxis import py_utils
import tensorflow.compat.v2 as tf


@dataclasses.dataclass
class JaxDistributedOptions:
  """Options forwarded positionally to `jax.distributed.initialize`."""

  # Address of the coordination service, e.g. "host:port".
  coordinator_address: str
  # Total number of JAX processes in the job.
  num_processes: int
  # Index of this process, in [0, num_processes).
  process_id: int


@py_utils.benchmark('[PAX STATUS]: ')
def setup_jax(
    globally_use_hardware_rng: bool,
    jax_backend_target: str | None,
    jax_xla_backend: str | None,
    jax_enable_checks: bool,
    jax_traceback_filtering_option: str = 'auto',
    should_initialize_jax_distributed: bool = False,
    jax_distributed_options: JaxDistributedOptions | None = None,
) -> None:
  """Setups JAX and logs information about this job.

  Args:
    globally_use_hardware_rng: If True, switch praxis to the RBG PRNG key
      implementation globally.
    jax_backend_target: Optional JAX backend target to log.
    jax_xla_backend: Optional JAX XLA backend name to log.
    jax_enable_checks: If True, turn on the `jax_enable_checks` config flag.
    jax_traceback_filtering_option: Value for the `jax_traceback_filtering`
      config option.
    should_initialize_jax_distributed: Whether to call
      `jax.distributed.initialize`.
    jax_distributed_options: Explicit coordinator/process options for
      `jax.distributed.initialize`; if None, JAX auto-detects them.
  """

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')
  if globally_use_hardware_rng:
    py_utils.set_globally_use_rbg_prng_key()

  # Log tracing and compilation time.
  jax.config.update('jax_log_compiles', True)

  # Allow users to configure JAX traceback filtering.
  # https://github.com/google/jax/blob/main/jax/_src/config.py
  jax.config.update('jax_traceback_filtering', jax_traceback_filtering_option)

  if jax_enable_checks:
    jax.config.update('jax_enable_checks', True)
    logging.info('jax_enable_checks has been enabled.')

  if jax_backend_target:
    logging.info('Using JAX backend target %s', jax_backend_target)
  # Log the XLA backend even when it is unset ('None' placeholder).
  jax_xla_backend = 'None' if jax_xla_backend is None else jax_xla_backend
  logging.info('Using JAX XLA backend %s', jax_xla_backend)

  if should_initialize_jax_distributed:
    if jax_distributed_options:
      jax.distributed.initialize(
          jax_distributed_options.coordinator_address,
          jax_distributed_options.num_processes,
          jax_distributed_options.process_id,
      )
    else:
      # Let JAX auto-detect the coordinator and process topology.
      jax.distributed.initialize()

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX devices: %r', jax.devices())
  logging.info('jax.device_count(): %d', jax.device_count())
  logging.info('jax.local_device_count(): %d', jax.local_device_count())
  logging.info('jax.process_count(): %d', jax.process_count())
-------------------------------------------------------------------------------- /paxml/pip_package/build.sh: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE(review): this shebang is not on line 1 (the license header precedes
# it), so it has no effect when the script is executed directly; callers
# invoke it via `bash pip_package/build.sh`.
#!/bin/bash

# This script uses custom-op docker, downloads code, builds and tests and then
# builds a pip package.

# See README.md for instructions to use this script.

set -e -x

# Override the following env variables if necessary.
export PYTHON_VERSION="${PYTHON_VERSION:-3}"
export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION}"
export PIP_MANYLINUX2010="${PIP_MANYLINUX2010:-1}"
export DEST="${WHEEL_FOLDER:-/tmp/wheels}"

# Select e.g. "python3" or "python3.10" depending on whether a minor version
# was provided, and make it the default python3.
if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then
  PYTHON="python${PYTHON_VERSION}"
else
  PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}"
fi
update-alternatives --install /usr/bin/python3 python3 "/usr/bin/$PYTHON" 1

# Append a single line to the generated .bazelrc.
function write_to_bazelrc() {
  echo "$1" >> .bazelrc
}

# Append an --action_env NAME="VALUE" build setting to the .bazelrc.
function write_action_env_to_bazelrc() {
  write_to_bazelrc "build --action_env $1=\"$2\""
}

# Remove .bazelrc if it already exist
[ -e .bazelrc ] && rm .bazelrc

write_to_bazelrc "build -c opt"
write_to_bazelrc 'build --cxxopt="-std=c++14"'
write_to_bazelrc 'build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"'
write_to_bazelrc 'build --auto_output_filter=subpackages'
write_to_bazelrc 'build --copt="-Wall" --copt="-Wno-sign-compare"'
write_to_bazelrc 'build --linkopt="-lrt -lm"'

TF_NEED_CUDA=0
echo 'Using installed tensorflow'
# Query the installed TF for its compile/link flags so bazel builds against
# the same headers and shared library.
TF_CFLAGS=( $(${PYTHON} -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS="$(${PYTHON} -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')"

write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2}
SHARED_LIBRARY_DIR=${TF_LFLAGS:2}
SHARED_LIBRARY_NAME=$(echo $TF_LFLAGS | rev | cut -d":" -f1 | rev)
# If TF_LFLAGS carries no ':'-delimited library name, fall back to the
# platform's default framework library name.
if ! [[ $TF_LFLAGS =~ .*:.* ]]; then
  if [[ "$(uname)" == "Darwin" ]]; then
    SHARED_LIBRARY_NAME="libtensorflow_framework.dylib"
  else
    SHARED_LIBRARY_NAME="libtensorflow_framework.so"
  fi
fi
write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${SHARED_LIBRARY_DIR}
write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${SHARED_LIBRARY_NAME}
write_action_env_to_bazelrc "TF_NEED_CUDA" ${TF_NEED_CUDA}

# Build everything and run the test suite (the vision input generator test is
# excluded), then package the wheel and record the frozen dependency set.
bazel clean
bazel build ...
bazel test --test_output=all --test_verbose_timeout_warnings -- paxml/... -paxml/tasks/vision:input_generator_test

./pip_package/build_pip_pkg.sh "$DEST" ${PYTHON_VERSION}
pip3 freeze > "${DEST}/dependencies.txt"
-------------------------------------------------------------------------------- /paxml/tools/fiddle/remove_sharding.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pass to remove sharding annotations from a model.

This can be useful for visualization, and code generation. Currently, the goal
is not perfection; for visualization we just want to remove enough to make
config more readable, and for code generation, sharding annotations are added
back in later (it just helps to factor them into a separate function).
"""

from typing import TypeVar

import fiddle as fdl
from fiddle import daglish
from praxis import base_layer
from praxis import pax_fiddle

_T = TypeVar("_T")

# BaseLayer attribute names whose values are sharding annotations.
BASE_LAYER_SHARDING_FIELDS = {
    "weight_split_dims_mapping",
    "activation_split_dims_mapping",
}


# Buildable callables that represent sharding configuration objects.
SHARDING_TYPES = (
    base_layer.BaseLayer.WeightSharding,
    base_layer.BaseLayer.ActivationSharding,
)


def _is_sharding_annotation(value, path: daglish.Path) -> bool:
  """Returns whether the current value or path is for a sharding annotation."""
  # Path-based check: the value sits in a known sharding field of BaseLayer.
  if path:
    last_elt = path[-1]
    if (
        isinstance(last_elt, daglish.Attr)
        and last_elt.name in BASE_LAYER_SHARDING_FIELDS
    ):
      return True

  # Type-based check: the value is a config of a sharding class itself.
  if isinstance(value, fdl.Buildable):
    fn_or_cls = fdl.get_callable(value)
    return isinstance(fn_or_cls, type) and issubclass(fn_or_cls, SHARDING_TYPES)

  return False


class RemoveSentinel:
  """Marker type used to tag values for deletion during traversal."""

  pass


# Single shared sentinel instance; identity-compared below.
_remove_sentinel = RemoveSentinel()


def remove_sharding(config: _T, replace_with_default: bool = False) -> _T:
  """Removes sharding annotations from a config.

  Args:
    config: Base configuration or structure of configuration.
    replace_with_default: Instead of just removing the sharding annotations,
      replace them with the default values.

  Returns:
    Config without sharding annotations.
  """

  def transform(value, state: daglish.State):
    # Children are rewritten first, so sharding sub-values arrive here
    # already replaced by the sentinel and can be stripped from the parent.
    value = state.map_children(value)
    if _is_sharding_annotation(value, state.current_path):
      return _remove_sentinel
    elif isinstance(value, fdl.Buildable):
      for name, sub_value in fdl.ordered_arguments(value).items():
        if sub_value is _remove_sentinel:
          delattr(value, name)
          if replace_with_default:
            # Re-populate the field from a freshly built default config.
            default_obj = pax_fiddle.Config(fdl.get_callable(value))
            setattr(value, name, getattr(default_obj, name))
    return value

  return daglish.MemoizedTraversal.run(transform, config)
-------------------------------------------------------------------------------- /paxml/pip_package/prepare_release.sh: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash

# This script prepares a new release by:
# 1) updating the version numbers in setup.py and cloudbuild-release.yaml
# 2) adding a new section in RELEASE.md with the version and corresponding
#    commits

set -e -x

# Prints usage and exits successfully.
function print_help_and_exit {
  echo "Usage: prepare_release.sh -v <paxml version> -x <praxis version> -d <build date>"
  echo "exp: bash prepare_release.sh -v 0.2.0 -x 0.2.0 -d 20221114"
  exit 0
}

# -v paxml version (required), -x praxis version (optional),
# -d build date YYYYMMDD (required).
while getopts "hv:d:x:" opt; do
  case $opt in
    v)
      PAXML_VERSION=${OPTARG}
      ;;
    x)
      PRAXIS_VERSION=${OPTARG}
      ;;
    d)
      BUILD_DATE=${OPTARG}
      ;;
    *)
      print_help_and_exit
      ;;
  esac
done

RELEASE_NOTE="../RELEASE.md"
RELEASE_NOTE_NEW="release_new.md"

if [[ -z "$BUILD_DATE" ]]; then
  echo "Build date is required!"
  exit 1
fi

if [[ -z "$PAXML_VERSION" ]]; then
  echo "paxml version is required!"
  exit 1
fi

echo "Build date: "$BUILD_DATE
echo "PAXML version: "$PAXML_VERSION

# Only update the praxis pin when a praxis version was supplied.
if [[ ! -z "$PRAXIS_VERSION" ]]; then
  sed -i "s/_PRAXIS_VERSION: '[0-9.]*'/_PRAXIS_VERSION: '$PRAXIS_VERSION'/" cloudbuild-release.yaml
fi

sed -i "s/version='[0-9.]*'/version='$PAXML_VERSION'/" setup.py
sed -i "s/_RELEASE_VERSION: '[0-9.]*'/_RELEASE_VERSION: '$PAXML_VERSION'/" cloudbuild-release.yaml

# Fetch the commit hashes the dated wheels were built from.
gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/paxml_commit.txt ./
gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/praxis_commit.txt ./
# NOTE(review): the original text here was garbled (everything between "$(<"
# and the next ">" was stripped, HTML-style). Reconstructed from the variables
# used below -- confirm against version-control history.
PAXML_COMMIT=$(<paxml_commit.txt)
PRAXIS_COMMIT=$(<praxis_commit.txt)

# Build the new release-note section, then prepend it to RELEASE.md.
echo "# Version: $PAXML_VERSION" > $RELEASE_NOTE_NEW
echo "## Major Features and Improvements" >> $RELEASE_NOTE_NEW
echo "## Breaking changes" >> $RELEASE_NOTE_NEW
echo "## Deprecations" >> $RELEASE_NOTE_NEW
echo "## Note" >> $RELEASE_NOTE_NEW
echo "* Version: $PAXML_VERSION" >> $RELEASE_NOTE_NEW
echo "* Build Date: $BUILD_DATE" >> $RELEASE_NOTE_NEW
echo "* Paxml commit: $PAXML_COMMIT" >> $RELEASE_NOTE_NEW
echo "* Praxis version: $PRAXIS_VERSION" >> $RELEASE_NOTE_NEW
echo "* Praxis commit: $PRAXIS_COMMIT" >> $RELEASE_NOTE_NEW
RELEASE_NOTE_TMP="RELEASE.tmp.md"
cat $RELEASE_NOTE_NEW $RELEASE_NOTE >> $RELEASE_NOTE_TMP
rm $RELEASE_NOTE_NEW
rm $RELEASE_NOTE
mv $RELEASE_NOTE_TMP $RELEASE_NOTE
seqio_input\n", 26 | "import numpy as np" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "2PJLV0fSCsyM" 33 | }, 34 | "source": [ 35 | "Let's look at some eval data." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "x2-q5_vhBByz" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import t5.data.mixtures\n", 47 | "p_eval = seqio_input.SeqIOInput.HParams(\n", 48 | " mixture_name='super_glue_copa_v102',\n", 49 | " split_name='validation',\n", 50 | " task_feature_lengths={\n", 51 | " 'inputs': 256,\n", 52 | " 'targets': 32,\n", 53 | " },\n", 54 | " feature_converter=seqio_input.LanguageModelFeatures(pack=False),\n", 55 | " batch_size=4,\n", 56 | " infeed_host_index=1,\n", 57 | " num_infeed_hosts=4,\n", 58 | " is_training=False,\n", 59 | " reset_for_eval=True)\n", 60 | "inp_eval = base_hyperparams.instantiate(p_eval)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "id": "q0gj3RvcBoNU" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# The dataset is finite, and raises exception after being exhausted\n", 72 | "inp_eval.reset()\n", 73 | "for i in range(10):\n", 74 | " try:\n", 75 | " batch = inp_eval.get_next()\n", 76 | " except StopIteration:\n", 77 | " print(i, 'batches')\n", 78 | " break" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "he9Np05oCHD_" 85 | }, 86 | "source": [ 87 | "Try tweaking `infeed_host_index` above and verify that each host raises after the same number of batches. What happens if we change `num_infeed_hosts`? How does that change the number of batches it takes to exhaust the data?\n", 88 | "\n", 89 | "Also verify that `batch.eval_sample_weights` field is used to tell which examples are added as paddings." 
# Description:
#   Library of layers that implement ghost norm protocol
load("//paxml:paxml.bzl", "pytype_strict_library", "pytype_strict_test")
load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")

package(default_visibility = JAX_VISIBILITY)

# Core ghost-norm protocol definitions; depended on by every layer target.
pytype_strict_library(
    name = "base",
    srcs = ["base.py"],
    deps = [
        # Implicit flax.core dependency.
        # Implicit jax dependency.
        "//praxis:pytypes",
    ],
)

# Ghost-norm implementations of linear layers.
pytype_strict_library(
    name = "linears",
    srcs = ["linears.py"],
    deps = [
        ":base",
        # Implicit jax dependency.
        "//praxis:pytypes",
        "//praxis/layers",
        "//praxis/layers:base_ops",
    ],
)

# Ghost-norm embedding layers; reuses the linear implementations.
pytype_strict_library(
    name = "embedding",
    srcs = ["embedding.py"],
    deps = [
        ":base",
        ":linears",
        # Implicit jax dependency.
        # Implicit jax.experimental.sparse dependency.
        "//praxis:base_layer",
        "//praxis:pytypes",
        "//praxis/layers",
        "//praxis/layers:base_ops",
    ],
)

# Wrapper that adapts generic praxis layers to the ghost-norm protocol.
pytype_strict_library(
    name = "generic_wrapper",
    srcs = ["generic_wrapper.py"],
    deps = [
        ":base",
        ":embedding",
        ":linears",
        # Implicit flax.core dependency.
        # Implicit jax dependency.
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:pytypes",
        "//praxis/layers:attentions",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:linears",
        "//praxis/layers:normalizations",
        "//praxis/layers:transformers",
    ],
)

# Single test target covering all ghost-norm layer libraries above.
pytype_strict_test(
    name = "layers_test",
    srcs = ["layers_test.py"],
    deps = [
        ":base",
        ":embedding",
        ":generic_wrapper",
        ":linears",
        # Implicit absl.testing.absltest.absltest dependency.
        # Implicit absl.testing.parameterized dependency.
        # Implicit jax dependency.
        # Implicit numpy dependency.
        # Implicit optax dependency.
        "//praxis:base_layer",
        "//praxis:pax_fiddle",
        "//praxis:py_utils",
        "//praxis/layers:embedding_softmax",
        "//praxis/layers:linears",
        "//praxis/layers:transformer_models",
        # Implicit tensorflow_no_contrib dependency.
    ],
)
class LoRAMixin(ABC):
  """Experiment mixin that injects LoRA adapter layers into a transformer.

  Subclasses set USE_LORA (and optionally LORA_RANK / LORA_TARGET_LAYERS)
  and call configure_lora() on their task config.
  """

  # Master switch; when False, configure_lora() returns the task unchanged.
  USE_LORA = False
  # Rank of the low-rank LoRA adapter matrices.
  LORA_RANK = 8
  # Which layers receive adapters: "all", "attention", or "mlp".
  LORA_TARGET_LAYERS = "all"

  def _validate(self):
    """Raises ValueError when LORA_TARGET_LAYERS has an unsupported value."""
    if self.LORA_TARGET_LAYERS not in ["all", "attention", "mlp"]:
      # Bug fix: the message previously referred to a nonexistent
      # "LAYERS_TO_INCLUDE_FOR_LORA" attribute; the checked attribute is
      # LORA_TARGET_LAYERS.
      raise ValueError(
          "LORA_TARGET_LAYERS should be one of all, attention or mlp."
      )

  def configure_lora(
      self, task_p: pax_fiddle.Config[tasks_lib.SingleTask]
  ) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Rewrites `task_p` so the selected layer templates use LoRA variants.

    Args:
      task_p: Config of the task to modify (mutated in place).

    Returns:
      The same `task_p`, with LoRA layer templates substituted and gradient
      updates restricted to LoRA variables.
    """
    if not self.USE_LORA:
      return task_p

    self._validate()
    train_p = task_p.train

    # NOTE(review): this unconditionally overwrites any preexisting ignore
    # rules whenever the attribute exists -- confirm that is intended.
    if hasattr(self, "CHECKPOINT_IGNORE_RULES"):
      self.CHECKPOINT_IGNORE_RULES = [r"^.*lora.*$"]

    # Only LoRA variables receive gradients; base weights stay frozen.
    train_p.learner.bprop_variable_inclusion = [r"^.*lora.*$"]
    stacked_p = task_p.model.lm_tpl.stacked_transformer_tpl
    # Unwrap the repeated-stack wrapper to reach the per-layer template.
    if issubclass(
        fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated
    ):
      stacked_p = stacked_p.block
    stacked_p = stacked_p.transformer_layer_params_tpl

    if self.LORA_TARGET_LAYERS in ["all", "mlp"]:
      ff_templ = stacked_p.tr_fflayer_tpl.fflayer_tpl
      original_linear_p = ff_templ.linear_tpl
      ff_templ.linear_tpl = pax_fiddle.Config(
          LoraLinear,
          rank=self.LORA_RANK,
          name="lora_linear",
      )
      # Preserve the settings of the template being replaced.
      ff_templ.linear_tpl.copy_fields_from(original_linear_p)

    if self.LORA_TARGET_LAYERS in ["all", "attention"]:
      # Fused QKV projections only exist on some attention templates.
      if hasattr(stacked_p.tr_atten_tpl, "combined_qkv_proj_tpl"):
        original_combined_qkv_p = stacked_p.tr_atten_tpl.combined_qkv_proj_tpl
        stacked_p.tr_atten_tpl.combined_qkv_proj_tpl = pax_fiddle.Config(
            LoraCombinedQKVProjection,
            name="lora_qkv_projection",
            rank=self.LORA_RANK,
        )
        stacked_p.tr_atten_tpl.combined_qkv_proj_tpl.copy_fields_from(
            original_combined_qkv_p
        )

      original_proj_p = stacked_p.tr_atten_tpl.proj_tpl
      stacked_p.tr_atten_tpl.proj_tpl = pax_fiddle.Config(
          LoraAttentionProjection,
          name="lora_attention_projection",
          rank=self.LORA_RANK,
      )
      stacked_p.tr_atten_tpl.proj_tpl.copy_fields_from(original_proj_p)

    return task_p
def from_legacy(
    experiment_cls: Type[base_experiment.BaseExperiment],
    *,
    # NOTE(review): the default normalizer is constructed once at import time
    # and shared across calls -- confirm default_normalizer() is stateless.
    normalizer: config_normalization.ConfigNormalizer
    | None = config_normalization.default_normalizer(),
    has_train_dataset: bool = True,
    has_input_specs_provider: bool = True,
) -> pax_fiddle.Config[parameterized_experiment.ParameterizedExperiment]:
  """Returns a ParameterizedExperiment config from a legacy experiment.

  Args:
    experiment_cls: Subclass of BaseExperiment.
    normalizer: Object that will normalize the output configuration. Pass None
      if you don't want any normalization.
    has_train_dataset: Whether the experiment has train datasets. Usually this
      is true, but some experiments may be inference-only, and calling their
      .train_datasets() method might raise an error. Set this to False in those
      cases.
    has_input_specs_provider: Likewise, usually it's safe to leave this as its
      default (True), but for occasional situations like testing, it may be
      reasonable to disable.

  Returns:
    A pax_fiddle.Config for ParameterizedExperiment, normalized by
    `normalizer` when one is provided.
  """
  experiment: base_experiment.BaseExperiment = experiment_cls()

  # Get the task configuration, modulo any changes.
  task_config = experiment.task()

  # Eval datasets are all non-training datasets reported by the experiment.
  dataset_configs = experiment.datasets()
  eval_datasets = [
      dataset_config
      for dataset_config in dataset_configs
      if not dataset_config.is_training
  ]
  # decode_datasets() may return any sequence type; coerce to a list.
  decode_datasets = experiment.decode_datasets()
  if not isinstance(decode_datasets, list):
    decode_datasets = list(decode_datasets)
  overall_config = pax_fiddle.Config(
      parameterized_experiment.ParameterizedExperiment,
      task=task_config,
      eval_datasets=eval_datasets,
  )
  # Optional fields are only set when requested/present, so the resulting
  # config stays minimal.
  if has_train_dataset:
    overall_config.train_datasets = experiment.train_datasets()
  if has_input_specs_provider:
    overall_config.input_specs_provider = (
        experiment.get_input_specs_provider_params()
    )
  if decode_datasets:
    overall_config.decode_datasets = decode_datasets

  # Now run normalization, and return the result.
  return normalizer(overall_config) if normalizer else overall_config
class DummyInference(base_inference_runner.BaseInferenceRunner):
  """Inference runner that returns a canned `output` with a fixed schema."""

  # Canned value returned by infer(), regardless of inputs.
  output: Any = None
  # Nested tfds FeatureConnectors describing `output` (per example).
  output_schema: Any = None

  def infer(self, train_state: TrainState, prng_key: PRNGKey,
            var_weight_hparams: NestedWeightHParams,
            input_batch: NestedMap) -> NestedMap:
    # All arguments are ignored; the configured output is returned as-is.
    return self.output


class BaseInferenceRunnerTest(test_utils.TestCase):

  def test_infer(self):
    # Batched outputs with leading (batch) dimension 8: a float tensor plus a
    # nested bytes array.
    dummy_output = NestedMap(
        tensor=np.arange(64, dtype=np.float32).reshape(8, 8),
        nested=NestedMap(
            text=np.array([f'{i}'.encode('utf-8') for i in range(8)],
                          dtype=object)))
    # Per-example schema: serialized examples drop the batch dimension.
    dummy_schema = NestedMap(
        tensor=tfds.features.Tensor(shape=(8,), dtype=tf.float32),
        nested=NestedMap(text=tfds.features.Text()))

    infer_runner_p = pax_fiddle.Config(
        DummyInference, output=dummy_output, output_schema=dummy_schema
    )
    infer_runner = infer_runner_p.Instantiate(model=None)

    serialized_outputs = infer_runner.serialize_outputs(
        # Pass dummy values to all 4 arguments of infer().
        infer_runner.infer(*([None] * 4)))

    # Unstacking along the batch axis yields the per-example ground truth.
    expected_outputs: Sequence[NestedMap] = py_utils.tree_unstack(
        dummy_output, 0
    )
    self.assertEqual(len(serialized_outputs), len(expected_outputs))

    # Round-trip check: deserializing each example must reproduce the
    # original leaves exactly.
    features_dict = tfds.features.FeaturesDict(dummy_schema)
    for serialized, expected in zip(serialized_outputs, expected_outputs):
      output = features_dict.deserialize_example(serialized)
      output_np = jax.tree.map(lambda x: x.numpy(), output)

      for output_leaf, expected_leaf in zip(
          jax.tree_util.tree_leaves(output_np),
          jax.tree_util.tree_leaves(expected)):
        self.assertArraysEqual(output_leaf, expected_leaf)


if __name__ == '__main__':
  absltest.main()
15 | 16 | """Expose functionalities for profiling code.""" 17 | 18 | from absl import logging 19 | 20 | 21 | class Profiler: 22 | """Dummy class to capture code profiles. 23 | 24 | Note: The current implementation is a no-op. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | num_steps: float = 2.0, 30 | min_duration_sec: float = 1.0, 31 | default_duration_sec: float = 5.0, 32 | tag: str | None = None, 33 | max_num_hosts: int | None = None, 34 | capture_host_profile: bool = False, 35 | ) -> None: 36 | """Constructor. 37 | 38 | Args: 39 | num_steps: The number of steps to capture based on the step duration 40 | estimate that is set by calling update_step_moving_mean() successfully. 41 | min_duration_sec: The minimum duration of the profiler capture in seconds. 42 | Set to this value when the estimate step duration times num_steps is 43 | smaller than this value. 44 | default_duration_sec: The default duration of the profiler capture in 45 | seconds. Used when no step duration were sampled by calling 46 | update_step_moving_mean(). 47 | tag: An optional tag to be added to the profiler trace. 48 | max_num_hosts: If max_num_hosts is unspecified, all hosts of devices will 49 | be chosen. Otherwise, at most max_num_hosts will be chosen. This option 50 | only works with pathways. 51 | capture_host_profile: Capture host CPU profile as well. 52 | """ 53 | self._capture_num_steps = num_steps 54 | self._capture_min_duration_sec = min_duration_sec 55 | self._capture_default_duration_sec = default_duration_sec 56 | self._tag = tag 57 | self._max_num_hosts = max_num_hosts 58 | self._capture_host_profile = capture_host_profile 59 | self._step_duration_sec = 0. 60 | self._step_count = 0 61 | 62 | def capture_async(self) -> None: 63 | """Captures a trace asynchronously. 64 | 65 | The duration of the trace corresponds to step_duration_estimate_sec. 
66 | """ 67 | logging.info('Dummy profiler currently does not capture any trace.') 68 | 69 | def update_step_moving_mean(self, duration_sec: float): 70 | """Updates the step duration moving average with a step duration estimate. 71 | 72 | Args: 73 | duration_sec: The duration of the step to add in seconds. 74 | """ 75 | self._step_duration_sec += duration_sec 76 | self._step_count += 1 77 | 78 | @property 79 | def step_duration_estimate_sec(self) -> float: 80 | """Estimates of the step duration in seconds. 81 | 82 | If update_step_moving_mean() has never been called, returns the default 83 | duration instead. 84 | """ 85 | if not self._step_count: 86 | return self._capture_default_duration_sec 87 | return self._step_duration_sec / self._step_count 88 | -------------------------------------------------------------------------------- /paxml/base_inference_runner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class BaseInferenceRunner(base_hyperparams.FiddleBaseParameterizable, abc.ABC):
  """Abstract base class for users to override.

  This class is essentially a container for a functional infer method and
  output schema definition. It defines (1) how to run inference to generate
  outputs given a model and some inputs, and (2) the corresponding schema for
  the output.

  TODO(b/238220793): Currently we only write Jax native types since we do all
  computations in a jit-ed context. We may eventually want to support non jax-
  native types such as strings.
  """
  # Internal mirror of `model`, set in __post_init__; excluded from the
  # generated __init__ and repr.
  _model: Any = dataclasses.field(init=False, repr=False)
  # Model to run inference with; supplied at instantiation time.
  model: base_model.BaseModel = None

  def __post_init__(self):
    # Copy the public field to the private attribute subclasses may rely on.
    self._model = self.model

  @abc.abstractmethod
  def infer(self, train_state: TrainState, prng_key: PRNGKey,
            var_weight_hparams: NestedWeightHParams,
            input_batch: NestedMap) -> NestedMap:
    """Generates some output given a model and input. Should be pmap-able."""

  @property
  @abc.abstractmethod
  def output_schema(self) -> NestedMap:
    """Returns the schema for the output to be serialized.

    This must be a nested map of `tfds.features.FeatureConnector` types. See
    https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeatureConnector
    for more information. The following is an example:

    ```
    return NestedMap(
        bucket_keys=tfds.features.Scalar(dtype=tf.int32),
        nested=NestedMap(
            field=tfds.features.Tensor(shape=(1000,), dtype=tf.int32)
        ),
        logprobs=tfds.features.Tensor(shape=(1, 32,), dtype=tf.float32),
    )
    ```
    """

  def serialize_outputs(self, outputs: NestedMap) -> list[bytes]:
    """Serializes batched infer() outputs into one record per example.

    Args:
      outputs: NestedMap of arrays stacked along a leading batch dimension,
        matching `output_schema`.

    Returns:
      One serialized example (bytes) per row of the batch.
    """
    # Outputs are assumed stacked along the leading (batch) dimension.
    input_batch_dim = 0
    features_dict = tfds.features.FeaturesDict(self.output_schema)
    examples = py_utils.tree_unstack(outputs, input_batch_dim)

    serialized_examples = []
    for ex in examples:
      serialized_examples.append(features_dict.serialize_example(ex))

    return serialized_examples
class OptimalScaling(LmCloudSpmd):
  """Decoder-only language model configurations with Chinchilla-like scaling."""

  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM

  # subclasses override these
  PERCORE_BATCH_SIZE = None
  NUM_LAYERS = None

  def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    """Derives model/hidden dims from NUM_LAYERS, then builds the base task."""
    # pylint: disable=invalid-name
    assert self.NUM_LAYERS
    # Width scales linearly with depth (128 units per layer); the hidden
    # (MLP) size uses the conventional 4x multiplier.
    self.MODEL_DIMS = self.NUM_LAYERS * 128
    self.HIDDEN_DIMS = self.MODEL_DIMS * 4
    # pylint: enable=invalid-name
    return super().task()


# NOTE(review): the class-name suffixes (e.g. 2x2x1) presumably encode the
# TPU slice topology each config targets -- confirm against launch configs.


@experiment_registry.register
class OptimalScaling2x2x1(OptimalScaling):
  NUM_LAYERS = 28
  PERCORE_BATCH_SIZE = 16
  ICI_MESH_SHAPE = [1, 4, 1]


@experiment_registry.register
class OptimalScaling2x2x2(OptimalScaling):
  NUM_LAYERS = 32
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 8, 1]


@experiment_registry.register
class OptimalScaling2x2x4(OptimalScaling):
  NUM_LAYERS = 36
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 16, 1]


@experiment_registry.register
class OptimalScaling2x4x4(OptimalScaling):
  NUM_LAYERS = 40
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 32, 1]


@experiment_registry.register
class OptimalScaling4x4x4(OptimalScaling):
  NUM_LAYERS = 45
  PERCORE_BATCH_SIZE = 8
  ICI_MESH_SHAPE = [1, 64, 1]


@experiment_registry.register
class OptimalScaling4x4x8(OptimalScaling):
  NUM_LAYERS = 50
  PERCORE_BATCH_SIZE = 4
  ICI_MESH_SHAPE = [1, 128, 1]


@experiment_registry.register
class OptimalScaling4x8x8(OptimalScaling):
  NUM_LAYERS = 56
  PERCORE_BATCH_SIZE = 2
  ICI_MESH_SHAPE = [1, 64, 4]


@experiment_registry.register
class OptimalScaling4x8x16(OptimalScaling):
  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 2
  ICI_MESH_SHAPE = [1, 128, 4]


@experiment_registry.register
class OptimalScaling4x16x16(OptimalScaling):
  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 256, 4]


@experiment_registry.register
class OptimalScaling4x16x32(OptimalScaling):
  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 512, 4]


@experiment_registry.register
class OptimalScaling4x24x32(OptimalScaling):
  NUM_LAYERS = 64
  PERCORE_BATCH_SIZE = 1
  ICI_MESH_SHAPE = [1, 768, 4]
@dataclasses.dataclass(frozen=True)
class HighlevelParameterization(experimental_top_level_api.CodegenPass):
  """Parameterizes fixtures by a highlevel settings object."""

  # When True, high-level attribute names are emitted lowercased.
  lowercasing: bool = False

  def __call__(self, task: Any, **pass_kwargs: Any) -> Any:
    """Rewrites every fixture function to take `self` and use its attributes.

    Args:
      task: The codegen task to transform; must be a PaxCodegenTask. Mutated
        in place (accessed high-level settings are recorded on
        `task.highlevel_accesses`).
      **pass_kwargs: Unused extra pass arguments.

    Returns:
      The same task object.
    """
    assert isinstance(task, codegen_pax_code_ir.PaxCodegenTask)
    all_fns = task.top_level_call.all_fixture_functions()

    def process_fn(fn: code_ir.FixtureFunction):
      # Prepend a `self` parameter to the fixture's signature.
      self_name = code_ir.Name("self", is_generated=False)
      fn.parameters.insert(0, code_ir.Parameter(self_name))

      def add_self_to_calls(value, state: daglish.State):
        # Children first, so nested calls/values are already rewritten.
        value = state.map_children(value)

        # Convert calls to sub-fixtures like model_fixture() to
        # self.model_fixture().
        if isinstance(value, code_ir.SymbolOrFixtureCall):
          if isinstance(value.symbol_expression, code_ir.FixtureReference):
            value.symbol_expression = code_ir.AttributeExpression(
                base=code_ir.VariableReference(self_name),
                attribute=value.symbol_expression.name.value,
            )

        # Convert any instances of highlevel variables to relevant expressions.
        elif hasattr(value, "__highlevel_name__"):
          attribute_name = value.__highlevel_name__
          if self.lowercasing:
            attribute_name = attribute_name.lower()
          # Record the access so the generated class can expose the setting.
          task.highlevel_accesses[attribute_name] = value
          return code_ir.AttributeExpression(
              base=code_ir.VariableReference(self_name),
              attribute=attribute_name,
          )

        # Process dict keys too (normally ignored by daglish).
        if isinstance(value, dict):
          converted = {}
          for key, sub_value in value.items():
            if hasattr(key, "__highlevel_name__"):
              attribute_name = key.__highlevel_name__
              if self.lowercasing:
                attribute_name = attribute_name.lower()
              task.highlevel_accesses[attribute_name] = key
              # Replace the high-level key with an attribute access on self.
              key = code_ir.AttributeExpression(
                  base=code_ir.VariableReference(self_name),
                  attribute=attribute_name,
              )
            converted[key] = sub_value
          return converted

        return value

      fn.replace_with(daglish.MemoizedTraversal.run(add_self_to_calls, fn))

    for fn in all_fns:
      process_fn(fn)

    return task
9 | ARG cpu_base_image="ubuntu:22.04" 10 | ARG base_image=$cpu_base_image 11 | ARG wheel_folder 12 | ENV WHEEL_FOLDER $wheel_folder 13 | ENV PYTHON_VERSION="3" 14 | ENV PYTHON_MINOR_VERSION="10" 15 | 16 | # Pick up some TF dependencies 17 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common 18 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \ 19 | build-essential \ 20 | curl \ 21 | git \ 22 | pkg-config \ 23 | rename \ 24 | rsync \ 25 | unzip \ 26 | vim \ 27 | && \ 28 | apt-get clean && \ 29 | rm -rf /var/lib/apt/lists/* 30 | 31 | # Install python 3.10 32 | RUN apt-get update && apt-get install -y \ 33 | python3 python3-dev python3-pip python3-venv && \ 34 | rm -rf /var/lib/apt/lists/* && \ 35 | python3.10 -m pip install pip --upgrade && \ 36 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0 37 | 38 | # Install pip package 39 | RUN pip3 install pip-tools 40 | 41 | ARG bazel_version=5.1.1 42 | # This is to install bazel, for development purposes. 43 | ENV BAZEL_VERSION ${bazel_version} 44 | RUN mkdir /bazel && \ 45 | cd /bazel && \ 46 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ 47 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ 48 | chmod +x bazel-*.sh && \ 49 | ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ 50 | cd / && \ 51 | rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh 52 | 53 | COPY . 
/paxml 54 | 55 | RUN mkdir -p $WHEEL_FOLDER && cd paxml && \ 56 | git rev-parse HEAD > $WHEEL_FOLDER/paxml_commit.txt 57 | 58 | RUN git clone https://github.com/google/praxis.git && \ 59 | cd praxis && git rev-parse HEAD > $WHEEL_FOLDER/praxis_commit.txt 60 | 61 | RUN cp -r praxis/praxis /paxml/ 62 | 63 | RUN cd /praxis && \ 64 | pip-compile --quiet requirements.in \ 65 | --output-file $WHEEL_FOLDER/praxis_requirements.txt 66 | 67 | RUN cd /paxml && \ 68 | sed -i 's/praxis/#praxis/' requirements.in && \ 69 | pip-compile --quiet requirements.in \ 70 | /praxis/requirements.in \ 71 | --output-file $WHEEL_FOLDER/paxml_requirements.txt 72 | 73 | RUN pip3 install --no-deps -r paxml/paxml/pip_package/requirements.txt 74 | 75 | RUN mv paxml/paxml/pip_package /paxml/ 76 | RUN cd /paxml && bash pip_package/build.sh 77 | #TODO:enable -praxis/layers:normalizations_test once the new Lingvo pip package is released 78 | RUN cd praxis && \ 79 | bazel test \ 80 | --test_output=all \ 81 | --test_verbose_timeout_warnings \ 82 | -- \ 83 | praxis/... \ 84 | -praxis/layers:attentions_test \ 85 | -praxis/layers:convolutions_test \ 86 | -praxis/layers:ctc_objectives_test \ 87 | -praxis/layers:embedding_softmax_test \ 88 | -praxis/layers:models_test \ 89 | -praxis/layers:ngrammer_test \ 90 | -praxis/layers:normalizations_test \ 91 | -praxis/layers:transformer_models_test \ 92 | -praxis/layers:transformers_test 93 | 94 | RUN cd praxis && bash praxis/pip_package/build_pip_pkg.sh 95 | 96 | WORKDIR / 97 | 98 | CMD ["/bin/bash"] 99 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/codegen_tracer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tracer classes for the Pax-specific codegen.""" 17 | 18 | # Note: Fiddle and Pax devs are in collaboration; please generally do not import 19 | # private libraries from Fiddle. 20 | 21 | import dataclasses 22 | import enum 23 | from typing import Type, TypeVar 24 | 25 | from fiddle import daglish 26 | from paxml import base_experiment 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class BoolTracer: 31 | """Special-case tracer for bool values, since we can't inherit from bool.""" 32 | 33 | name: str 34 | value: bool 35 | 36 | @property 37 | def __highlevel_name__(self): 38 | return self.name 39 | 40 | def __bool__(self): 41 | return self.value 42 | 43 | def __eq__(self, other): 44 | if isinstance(other, bool): 45 | return self.value == other 46 | return super().__eq__(other) 47 | 48 | 49 | def make_tracer( 50 | name, value, allowed_types=(bool, int, float, str, list, tuple) 51 | ): 52 | """Wraps a value in a tracer object, that has a name.""" 53 | typ = type(value) 54 | if not issubclass(typ, allowed_types): 55 | raise ValueError( 56 | f"Type {typ} is not allowed. If it seems to work with " 57 | "the subclassed tracers, please add it to allowed_types " 58 | "and write a unit test." 59 | ) 60 | 61 | if typ == bool: 62 | return BoolTracer(name, value) 63 | wrapped = type( 64 | f"Wrapped{typ.__name__}_{name}", (typ,), {"__highlevel_name__": name} 65 | ) 66 | return wrapped(value) 67 | 68 | 69 | class TracerMixin: 70 | """Mixin that will wrap experiments' property accessors with traced versions. 
71 | 72 | This means that when there are high-level settings as attributes on a 73 | BaseExperiment instance, a traced value will be returned. This can then be 74 | intercepted in code generation, resulting in partially abstracted code. 75 | """ 76 | 77 | __trace_names__: set[str] 78 | 79 | def __init__(self, *args, **kwargs): 80 | super().__init__(*args, **kwargs) 81 | object.__setattr__(self, "__trace_names__", set()) 82 | 83 | def __getattribute__(self, name): 84 | result = super().__getattribute__(name) 85 | if name in {"__trace_names__"}: 86 | return result 87 | elif isinstance(result, enum.Enum): 88 | return result 89 | elif isinstance(result, (bool, int, float, str, list, tuple)): 90 | self.__trace_names__.add(name) 91 | return make_tracer(name, result) 92 | else: 93 | return result 94 | 95 | 96 | def make_subclass_mixin( 97 | experiment_cls: Type[base_experiment.BaseExperiment], 98 | ): 99 | """Creates a dynamic subclass of an experiment that adds TracerMixin.""" 100 | if not issubclass(experiment_cls, base_experiment.BaseExperiment): 101 | raise TypeError("Please pass a subclass of BaseExperiment.") 102 | cls_name = experiment_cls.__name__ 103 | return type(f"{cls_name}Traced", (experiment_cls, TracerMixin), {}) 104 | 105 | 106 | _T = TypeVar("_T") 107 | 108 | 109 | def remove_tracers(root: _T) -> _T: 110 | def transform(value, state: daglish.State): 111 | if isinstance(value, BoolTracer): 112 | value = bool(value) 113 | elif hasattr(value, "__highlevel_name__"): 114 | value = type(value).__bases__[0](value) 115 | return state.map_children(value) 116 | 117 | return daglish.MemoizedTraversal.run(transform, root) 118 | -------------------------------------------------------------------------------- /paxml/tools/dump_input_specs_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Input specification retrieval from either provider or input pipeline.""" 17 | 18 | import pprint 19 | 20 | from absl import logging 21 | import jax 22 | from paxml import base_experiment 23 | from praxis import base_hyperparams 24 | from praxis import base_input 25 | from praxis import pytypes 26 | import pyglove as pg 27 | 28 | NestedShapeDtypeStruct = pytypes.NestedShapeDtypeStruct 29 | instantiate = base_hyperparams.instantiate 30 | 31 | 32 | def extract_input_specs( 33 | experiment_config: base_experiment.BaseExperiment, 34 | ) -> tuple[NestedShapeDtypeStruct | None, NestedShapeDtypeStruct | None]: 35 | """Extracts the input specs for a given experiment config.""" 36 | logging.info('Starting extraction of input specs info.') 37 | 38 | # Input specs from input_specs_provider. 39 | input_specs_provider_p = experiment_config.get_input_specs_provider_params() 40 | input_specs_from_provider = None 41 | if not isinstance(input_specs_provider_p, 42 | base_input.DatasetInputSpecsProvider): 43 | logging.info('Extracting input specs info from provider.') 44 | specs_provider = instantiate(input_specs_provider_p) 45 | input_specs_from_provider = specs_provider.get_input_specs() 46 | 47 | # NOTE(daiyip): putting `training_dataset()` and `instantiate(train_input_p)` 48 | # under an AutoML context allows dynamic evaluation of hyperparameters that is 49 | # to be swept. 
The first values of all `pg.oneof` will be used. 50 | with pg.hyper.DynamicEvaluationContext(require_hyper_name=True).collect(): 51 | # Input specs from experiment config 52 | logging.info('Extracting input specs info from input pipeline.') 53 | try: 54 | # Clone it since we may mutate a few attributes below. 55 | train_input_p = experiment_config.train_datasets()[0].clone() 56 | except ValueError: 57 | logging.info('Could not find a training input pipeline for %s', 58 | experiment_config) 59 | train_input_p = None 60 | 61 | if train_input_p is None: 62 | return input_specs_from_provider, None 63 | 64 | # Attempt at reducing loading time when using Lingvo input. 65 | if isinstance(train_input_p, base_input.LingvoInputAdaptor): 66 | train_input_p.input.num_batcher_threads = 1 67 | train_input_p.input.file_parallelism = 1 68 | train_input_p.input.file_buffer_size = 32 69 | 70 | logging.info('Instantiating input pipeline...') 71 | input_pipeline = instantiate(train_input_p) 72 | logging.info('Retrieving specs from input pipeline...') 73 | input_specs_from_input_pipeline = jax.tree_util.tree_map( 74 | lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), 75 | input_pipeline.get_next_padded()) 76 | 77 | return input_specs_from_provider, input_specs_from_input_pipeline 78 | 79 | 80 | def specs_to_string( 81 | experiment_name: str, 82 | specs: tuple[NestedShapeDtypeStruct | None, NestedShapeDtypeStruct | None], 83 | ) -> str: 84 | """Converts input specs into a readable string.""" 85 | pp = pprint.PrettyPrinter(indent=2) 86 | specs_provider, specs_pipeline = specs 87 | out_lst = [] 88 | out_lst.append(experiment_name) 89 | out_lst.append('From InputSpecsProvider:') 90 | out_lst.append(pp.pformat(specs_provider)) 91 | out_lst.append('From training input pipeline:') 92 | out_lst.append(pp.pformat(specs_pipeline)) 93 | out_lst.append('\n\n') 94 | out_str = '\n'.join(out_lst) 95 | logging.info(out_str) 96 | return out_str 97 | 
-------------------------------------------------------------------------------- /paxml/tasks/lm/params/c4_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for GPT-3 models defined in c4.py.""" 17 | import os 18 | 19 | from absl.testing import absltest 20 | import fiddle as fdl 21 | import jax 22 | from paxml.tasks.lm.params import c4 23 | from praxis import layers 24 | from praxis import optimizers 25 | from praxis import schedules 26 | from praxis import test_utils 27 | 28 | prev_xla_flags = None 29 | 30 | 31 | def setUpModule(): 32 | global prev_xla_flags 33 | prev_xla_flags = os.getenv("XLA_FLAGS") 34 | flags_str = prev_xla_flags or "" 35 | # Don't override user-specified device count, or other XLA flags. 36 | if "xla_force_host_platform_device_count" not in flags_str: 37 | os.environ["XLA_FLAGS"] = ( 38 | flags_str + " --xla_force_host_platform_device_count=768" 39 | ) 40 | # Clear any cached backends so new CPU backend will pick up the env var. 
41 | jax.extend.backend.get_backend.cache_clear() 42 | 43 | 44 | def tearDownModule(): 45 | if prev_xla_flags is None: 46 | del os.environ["XLA_FLAGS"] 47 | else: 48 | os.environ["XLA_FLAGS"] = prev_xla_flags 49 | jax.extend.backend.get_backend.cache_clear() 50 | 51 | 52 | class C4Test(test_utils.TestCase): 53 | 54 | def test_gpt3_mlperf_bs1p5k_config(self): 55 | config = c4.C4SpmdPipelineGpt3AdamMLPerfHPBS1p5k768Replicas() 56 | task_p = config.task() 57 | 58 | # Model architecture 59 | lm_tpl = task_p.model.lm_tpl 60 | self.assertEqual(config.MAX_SEQ_LEN, 2048) 61 | self.assertEqual(config.NUM_LAYERS, 96) 62 | self.assertEqual(config.NUM_HEADS, 96) 63 | self.assertEqual(lm_tpl.model_dims, 12288) 64 | self.assertEqual(config.HIDDEN_DIMS, 12288 * 4) 65 | self.assertGreaterEqual(lm_tpl.vocab_size, 50257) 66 | self.assertEqual( 67 | lm_tpl.position_emb_tpl.cls, 68 | layers.embedding_softmax.TrainablePositionalEmbedding, 69 | ) 70 | 71 | global_batch_size = int( 72 | config.PERCORE_BATCH_SIZE * jax.device_count() + 1e-6 73 | ) 74 | self.assertEqual(global_batch_size, 1536) 75 | self.assertEqual( 76 | task_p.train.eval_interval_steps * global_batch_size, 24 * 1024 77 | ) 78 | 79 | # Early stopping fn 80 | self.assertEqual(task_p.early_stopping_fn.cls, c4.EarlyStoppingFn) 81 | self.assertAlmostEqual(task_p.early_stopping_fn.target_log_pplx, 2.69) 82 | 83 | # optimizer and HPs 84 | optimizer_p = task_p.train.learner.optimizer 85 | self.assertEqual(fdl.get_callable(optimizer_p), optimizers.Adam) 86 | self.assertAlmostEqual(optimizer_p.weight_decay, 0.1) 87 | self.assertAlmostEqual(optimizer_p.epsilon, 1e-8) 88 | self.assertAlmostEqual(optimizer_p.beta1, 0.9) 89 | self.assertAlmostEqual(optimizer_p.beta2, 0.95) 90 | self.assertAlmostEqual(optimizer_p.clip_gradient_norm_to_value, 1.0) 91 | 92 | # LR schedule 93 | lr_schedule = optimizer_p.lr_schedule 94 | self.assertEqual(lr_schedule.cls, schedules.LinearRampupCosineDecay) 95 | self.assertEqual(lr_schedule.warmup_steps * 
global_batch_size, 265 * 1536) 96 | self.assertEqual(lr_schedule.decay_start, lr_schedule.warmup_steps + 1) 97 | self.assertEqual(lr_schedule.decay_end * global_batch_size, 108600 * 1536) 98 | self.assertAlmostEqual(optimizer_p.learning_rate, 2e-5) 99 | self.assertAlmostEqual(lr_schedule.min_ratio, 0.1) 100 | self.assertAlmostEqual(lr_schedule.max, 1.0) 101 | 102 | 103 | if __name__ == "__main__": 104 | absltest.main() 105 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/codegen_external_init_checkpoint_fns_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for codegen_external_init_checkpoint_fns.""" 17 | 18 | from absl.testing import absltest 19 | from fiddle._src.codegen.auto_config import init_task 20 | from fiddle._src.codegen.auto_config import ir_printer 21 | from fiddle.experimental import visualize 22 | from paxml.tools.fiddle import codegen_external_init_checkpoint_fns 23 | from paxml.tools.fiddle import codegen_pax_code_ir 24 | from paxml.tools.fiddle import test_fixtures 25 | 26 | 27 | class InitCheckpointRulesFromOtherTaskTest(absltest.TestCase): 28 | 29 | def test_creates_calls(self): 30 | config = test_fixtures.SampleExperimentWithInitFromCheckpointRules().task() 31 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True) 32 | task = init_task.init_task(config) 33 | task = codegen_pax_code_ir.PaxCodegenTask( 34 | original_config=task.original_config, 35 | top_level_call=task.top_level_call, 36 | ) 37 | codegen_pass = ( 38 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask() 39 | ) 40 | task = codegen_pass( 41 | task, 42 | init_checkpoint_experiments={ 43 | "/path/to/my/checkpoint": ( 44 | test_fixtures.SampleExperimentWithInputSpecsProvider 45 | ) 46 | }, 47 | ) 48 | 49 | debug_str = ir_printer.format_task(task) 50 | self.assertIn( 51 | "task_p=call:.task(*[[]], **{})>", 54 | debug_str, 55 | ) 56 | self.assertIn( 57 | "input_specs_provider_p=call:.get_input_specs_provider_params(*[[]], **{})>", 60 | debug_str, 61 | ) 62 | 63 | def test_errors_unused_rule(self): 64 | config = test_fixtures.SampleExperiment().task() 65 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True) 66 | task = init_task.init_task(config) 67 | task = codegen_pax_code_ir.PaxCodegenTask( 68 | original_config=task.original_config, 69 | top_level_call=task.top_level_call, 70 | ) 71 | codegen_pass = ( 72 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask() 73 | ) 74 | with self.assertRaisesRegex( 75 | ValueError, r"Didn't 
encounter.*path/to/my/checkpoint" 76 | ): 77 | codegen_pass( 78 | task, 79 | init_checkpoint_experiments={ 80 | "/path/to/my/checkpoint": ( 81 | test_fixtures.SampleExperimentWithInputSpecsProvider 82 | ) 83 | }, 84 | ) 85 | 86 | def test_errors_unmatched_rule(self): 87 | config = test_fixtures.SampleExperimentWithInitFromCheckpointRules().task() 88 | config = visualize.with_defaults_trimmed(config, remove_deep_defaults=True) 89 | task = init_task.init_task(config) 90 | task = codegen_pax_code_ir.PaxCodegenTask( 91 | original_config=task.original_config, 92 | top_level_call=task.top_level_call, 93 | ) 94 | codegen_pass = ( 95 | codegen_external_init_checkpoint_fns.InitCheckpointRulesFromOtherTask() 96 | ) 97 | with self.assertRaisesRegex( 98 | ValueError, r"No task for checkpoint /path/to/my/checkpoint" 99 | ): 100 | codegen_pass( 101 | task, 102 | init_checkpoint_experiments={}, 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | absltest.main() 108 | -------------------------------------------------------------------------------- /paxml/tools/fiddle/remove_sharding_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for remove_sharding.""" 17 | 18 | import dataclasses 19 | import typing 20 | 21 | from absl.testing import absltest 22 | import fiddle as fdl 23 | from fiddle import daglish 24 | from paxml.tools.fiddle import remove_sharding 25 | from praxis import base_layer 26 | from praxis import pax_fiddle 27 | 28 | 29 | class TestLayer(base_layer.BaseLayer): 30 | x: int = 12 31 | 32 | 33 | def fake_config(): 34 | config = pax_fiddle.Config(TestLayer, x=14) 35 | config.weight_split_dims_mapping.wt = ("foo", "bar") 36 | return config 37 | 38 | 39 | def _is_sharding_config(typ): 40 | origin = typing.get_origin(typ) 41 | if not isinstance(origin, type) or not issubclass(origin, pax_fiddle.Config): 42 | return False 43 | args = typing.get_args(typ) 44 | if len(args) != 1: 45 | return False 46 | return args[0] in remove_sharding.SHARDING_TYPES 47 | 48 | 49 | class RemoveShardingTest(absltest.TestCase): 50 | 51 | def test_base_layer_sharding_fields(self): 52 | detected = { 53 | name 54 | for name, typ in typing.get_type_hints(base_layer.BaseLayer).items() 55 | if _is_sharding_config(typ) 56 | } 57 | self.assertEqual(detected, remove_sharding.BASE_LAYER_SHARDING_FIELDS) 58 | 59 | def test_is_sharding_annotation(self): 60 | self.assertTrue( 61 | remove_sharding._is_sharding_annotation( 62 | value=pax_fiddle.Config(base_layer.BaseLayer.WeightSharding), 63 | path=(), 64 | ) 65 | ) 66 | 67 | def test_is_sharding_annotation_works_with_function(self): 68 | self.assertFalse( 69 | remove_sharding._is_sharding_annotation( 70 | value=pax_fiddle.Config(fake_config), 71 | path=(), 72 | ) 73 | ) 74 | 75 | def test_is_sharding_annotation_by_path(self): 76 | # This is just to catch if the user forgets to inherit from WeightSharding / 77 | # ActivationSharding. 
78 | class Foo(base_layer.BaseLayer): 79 | 80 | @dataclasses.dataclass 81 | class MyActivationSplitDimsMapping: 82 | bar_sharding: list[str] = dataclasses.field(default_factory=list) 83 | 84 | activation_split_dims_mapping: MyActivationSplitDimsMapping = ( 85 | dataclasses.field(default_factory=MyActivationSplitDimsMapping) 86 | ) 87 | 88 | config = pax_fiddle.Config( 89 | Foo, 90 | activation_split_dims_mapping=pax_fiddle.Config( 91 | Foo.MyActivationSplitDimsMapping, bar_sharding=["baz", "qux"] 92 | ), 93 | ) 94 | self.assertTrue( 95 | remove_sharding._is_sharding_annotation( 96 | value=config.activation_split_dims_mapping, 97 | path=(daglish.Attr("activation_split_dims_mapping"),), 98 | ) 99 | ) 100 | 101 | def test_remove_sharding(self): 102 | config = fake_config() 103 | without_sharding = remove_sharding.remove_sharding(config=config) 104 | self.assertIn("weight_split_dims_mapping", fdl.ordered_arguments(config)) 105 | self.assertNotIn( 106 | "weight_split_dims_mapping", fdl.ordered_arguments(without_sharding) 107 | ) 108 | 109 | # Ensure other attributes were not deleted. 110 | self.assertEqual(config.x, 14) 111 | self.assertEqual(without_sharding.x, 14) 112 | 113 | def test_replace_with_defaults(self): 114 | config = fake_config() 115 | without_sharding = remove_sharding.remove_sharding( 116 | config=config, replace_with_default=True 117 | ) 118 | without_sharding.weight_split_dims_mapping.wt = () 119 | 120 | 121 | if __name__ == "__main__": 122 | absltest.main() 123 | -------------------------------------------------------------------------------- /paxml/tasks/lm/params/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Language modeling model configurations. 18 | 19 | load("//paxml:paxml.bzl", "pytype_library") 20 | load("//paxml:paxml.bzl", "py_strict_test") 21 | load("//paxml:build_defs.bzl", "pax_targets") 22 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 23 | 24 | package(default_visibility = JAX_VISIBILITY) 25 | 26 | licenses(["notice"]) 27 | 28 | pytype_library( 29 | name = "params", 30 | srcs = [ 31 | "bert.py", 32 | "c4.py", 33 | "c4_multislice.py", 34 | "lm_cloud.py", 35 | "optimal_scaling.py", 36 | ], 37 | tags = ["keep_dep"], 38 | deps = [ 39 | ":lm_cloud", 40 | # Implicit absl.logging dependency. 41 | # Implicit etils dependency. 42 | # Implicit fiddle dependency. 43 | # Implicit jax dependency. 44 | "//paxml:base_experiment", 45 | "//paxml:experiment_registry", 46 | "//paxml:seqio_input", 47 | "//paxml:tasks_lib", 48 | "//paxml:trainer_lib", 49 | "//paxml/tasks/lm:input_generator", 50 | "//paxml/tasks/lm:model_params", 51 | "//praxis:base_hyperparams", 52 | "//praxis:base_input", 53 | "//praxis:base_layer", 54 | "//praxis:optimizers", 55 | "//praxis:pax_fiddle", 56 | "//praxis:schedules", 57 | "//praxis/layers", 58 | "//praxis/layers:transformers", 59 | # Implicit seqio dependency. 60 | # Implicit t5.data dependency. 61 | # Implicit t5.data.preprocessors dependency. 62 | ], 63 | ) 64 | 65 | pytype_library( 66 | name = "nvidia", 67 | srcs = [ 68 | "nvidia.py", 69 | ], 70 | tags = ["keep_dep"], 71 | deps = [ 72 | ":lm_cloud", 73 | ":params", 74 | # Implicit fiddle dependency. 
75 | # Implicit jax dependency. 76 | # Implicit numpy dependency. 77 | "//paxml:experiment_registry", 78 | "//paxml:tasks_lib", 79 | "//paxml/tasks/lm:model_params", 80 | "//praxis:base_layer", 81 | "//praxis:optimizers", 82 | "//praxis:pax_fiddle", 83 | "//praxis:schedules", 84 | "//praxis/layers", 85 | "//praxis/layers:activations", 86 | "//praxis/layers:glam", 87 | "//praxis/layers:gpu_fast_attention", 88 | "//praxis/layers:grok", 89 | "//praxis/layers:transformers", 90 | ], 91 | ) 92 | 93 | pytype_library( 94 | name = "lm_cloud", 95 | srcs = [ 96 | "lm_cloud.py", 97 | ], 98 | deps = [ 99 | # Implicit absl.logging dependency. 100 | # Implicit jax dependency. 101 | "//paxml:base_experiment", 102 | "//paxml:experiment_registry", 103 | "//paxml:tasks_lib", 104 | "//paxml/tasks/lm:input_generator", 105 | "//paxml/tasks/lm:model_params", 106 | "//praxis:base_layer", 107 | "//praxis:pax_fiddle", 108 | "//praxis/layers", 109 | ], 110 | ) 111 | 112 | py_strict_test( 113 | name = "c4_test", 114 | timeout = "long", 115 | srcs = ["c4_test.py"], 116 | deps = [ 117 | ":params", 118 | # Implicit absl.testing.absltest.absltest dependency. 119 | # Implicit fiddle dependency. 120 | # Implicit jax dependency. 121 | "//praxis:optimizers", 122 | "//praxis:schedules", 123 | "//praxis:test_utils", 124 | "//praxis/layers", 125 | ], 126 | ) 127 | 128 | pax_targets( 129 | experiments = [ 130 | ":params", 131 | ], 132 | ) 133 | 134 | pax_targets( 135 | experiments = [ 136 | ":nvidia", 137 | ], 138 | prefix_name = "gpu", 139 | ) 140 | -------------------------------------------------------------------------------- /paxml/xla_passthrough.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Helpers for handling inputs unsupported by XLA.""" 17 | 18 | from absl import logging 19 | import jax 20 | import numpy as np 21 | 22 | 23 | def split_out_xla_unsupported_batch(batch, partitioning_spec=None): 24 | """Splits out values not supported by XLA (such as strings) from the batch. 25 | 26 | This is used to pass the unsupported values through a different channel. See 27 | also `merge_back_xla_unsupported_batch`. 28 | 29 | Args: 30 | batch: The input (possibly nested) dictionary. 31 | partitioning_spec: The dictionary with partitioning information or None. 32 | 33 | Returns: 34 | A tuple of the following elements. 35 | batch: The original batch dictionary with unsupported elements removed. 36 | unsupported_batch: A dictionary with only the unsupported elements. 37 | partitioning_spec: The original partitioning_spec with unsupported elements 38 | removed. 
39 | """ 40 | unsupported_batch = {} 41 | new_partitioning_spec = {} 42 | 43 | for k, v in batch.items(): 44 | if hasattr(v, 'items'): 45 | nested_batch, nested_unsupported_batch, nested_partitioning_spec = ( 46 | split_out_xla_unsupported_batch( 47 | v, 48 | partitioning_spec=partitioning_spec.get(k) 49 | if partitioning_spec 50 | else None, 51 | ) 52 | ) 53 | if nested_unsupported_batch: 54 | batch[k] = nested_batch 55 | unsupported_batch[k] = nested_unsupported_batch 56 | if ( 57 | partitioning_spec 58 | and k in partitioning_spec 59 | and nested_partitioning_spec 60 | ): 61 | new_partitioning_spec[k] = nested_partitioning_spec 62 | continue 63 | 64 | if not np.issubdtype(v.dtype, np.str_) and not np.issubdtype( 65 | v.dtype, np.object_ 66 | ): 67 | continue 68 | 69 | unsupported_batch[k] = v 70 | 71 | # If no unsupported keys were detected, return out the original batch object 72 | # without modifying it. 73 | if not unsupported_batch: 74 | return batch, {}, partitioning_spec 75 | 76 | # Similarly for the multi-host case, which is not supported yet: return out 77 | # the original batch object without modifying it. 78 | if jax.process_count() > 1 and unsupported_batch: 79 | # TODO(b/279795947): Support xla passthrough for multihost eval. 80 | raise NotImplementedError( 81 | ( 82 | 'Unsupported inputs (with keys %s) were detected, but running with' 83 | ' more than one host. Forwarding these keys is currently not' 84 | ' supported (but may be supported in the future).' 
85 | ) 86 | % unsupported_batch.keys(), 87 | ) 88 | 89 | batch = {k: v for k, v in batch.items() if k not in unsupported_batch} 90 | if partitioning_spec is not None: 91 | new_partitioning_spec.update( 92 | { 93 | k: v 94 | for k, v in partitioning_spec.items() 95 | if k not in unsupported_batch 96 | } 97 | ) 98 | else: 99 | new_partitioning_spec = None 100 | return batch, unsupported_batch, new_partitioning_spec 101 | 102 | 103 | def merge_back_xla_unsupported_batch(out, xla_unsupported_batch): 104 | """Adds back unsupported parts of the batch into out. 105 | 106 | This is done in case process_decode_out or other parts of the code want to 107 | make use of the unsupported parts. 108 | 109 | Args: 110 | out: The output dictionary without unsupported parts. 111 | xla_unsupported_batch: A dictionary with only the unsupported elements, if 112 | any. 113 | """ 114 | if xla_unsupported_batch: 115 | out.update(xla_unsupported_batch) 116 | -------------------------------------------------------------------------------- /paxml/tasks/lm/input_generator_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class InputTest(test_utils.TestCase):
  """Tests for TFRecordBertInput against a 10-example golden TFRecord file."""

  # We use first id and seq length as fingerprint to identify each shard
  # has the right elements.
  def _get_first_id_and_lengths(self, batch):
    # labels[:, 1] picks the second label token of every example; summing the
    # 0/1 segment_ids row-wise recovers each example's unpadded length.
    return batch.labels[:, 1], np.sum(batch.segment_ids, axis=1, dtype=np.int32)

  def test_full(self):
    """Reads all 10 examples in one batch and checks golden fingerprints."""
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # There are 10 examples in this test data file.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 10

    inp = instantiate(p)
    batch = inp.get_next()
    ids, lengths = self._get_first_id_and_lengths(batch)
    # Golden values extracted from the checked-in test data file.
    expected_ids = np.array(
        [2003, 1996, 1996, 2049, 3748, 1007, 4862, 1996, 2004, 2002],
        dtype=np.int32)
    expected_lengths = np.array([35, 239, 55, 56, 511, 511, 161, 43, 416, 511],
                                dtype=np.int32)
    self.assertArraysEqual(ids, expected_ids)
    self.assertArraysEqual(lengths, expected_lengths)

  def test_remask(self):
    """Runs the input with MLM re-masking enabled and checks output shapes."""
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # There are 10 examples in this test data file.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 10
    p.is_training = True
    p.remask = True
    p.mlm_augmenter.Set(mask_token_id=103, vocab_size=8000)

    inp = instantiate(p)
    batch = inp.get_next()
    ids, lengths = self._get_first_id_and_lengths(batch)
    # Re-masking is randomized, so only shapes are asserted here, not the
    # golden fingerprints used by test_full().
    self.assertEqual(ids.shape, (10,))
    self.assertEqual(lengths.shape, (10,))

  @parameterized.parameters(True, False)
  def test_sharded(self, provide_data_size):
    """Checks that 4-way sharding covers every example exactly once.

    Runs both with an explicitly provided eval_data_size and with 0 (in which
    case the input must discover the data size itself).
    """
    p = pax_fiddle.Config(input_generator.TFRecordBertInput)
    # There are 10 examples in this test data file.
    p.input_file = test_helper.test_src_dir_path('tasks/lm/testdata/tfrecords')
    p.batch_size = 4
    p.eval_data_size = 10 if provide_data_size else 0
    sharded_inputs = []
    for i in range(4):
      local_p = p.clone().set(infeed_host_index=i, num_infeed_hosts=4)
      sharded_inputs.append(instantiate(local_p))

    # This is the same as in test_full() above.
    expected_ids = np.array(
        [2003, 1996, 1996, 2049, 3748, 1007, 4862, 1996, 2004, 2002],
        dtype=np.int32)
    expected_lengths = np.array([35, 239, 55, 56, 511, 511, 161, 43, 416, 511],
                                dtype=np.int32)
    # 10 examples padded with 6 zero entries so they divide evenly into
    # 4 shards of batch_size 4; row i is the expected batch of shard i.
    expected_ids = np.reshape(
        np.concatenate(
            [expected_ids, np.array([0] * 6, dtype=np.int32)], axis=0), [4, -1])
    expected_lengths = np.reshape(
        np.concatenate([expected_lengths,
                        np.array([0] * 6, dtype=np.int32)],
                       axis=0), [4, -1])

    # Shard order is deliberately scrambled to show shards are independent.
    for i in [1, 3, 2, 0]:
      # each shard would produce one batch, and then out of range.
      batch = sharded_inputs[i].get_next()
      ids, lengths = self._get_first_id_and_lengths(batch)
      self.assertArraysEqual(ids, expected_ids[i])
      self.assertArraysEqual(lengths, expected_lengths[i])

      with self.assertRaisesRegex(tf.errors.OutOfRangeError, 'End of sequence'):
        sharded_inputs[i].get_next()
15 | 16 | """TrainState class for encapsulating model weights and optimizer states.""" 17 | 18 | from __future__ import annotations 19 | 20 | import dataclasses 21 | from typing import Any, Generic, TypeVar 22 | 23 | from flax import struct as flax_struct 24 | import jax 25 | import jaxtyping as jt 26 | import optax 27 | from praxis import base_layer 28 | from praxis import py_utils 29 | from praxis import pytypes 30 | from praxis import trees 31 | 32 | JTensor = py_utils.JTensor 33 | JTensorProvenance = tuple[str, int | None] 34 | JTensorOrPartitionSpec = pytypes.JTensorOrPartitionSpec 35 | Nested = pytypes.Nested 36 | NestedJTensor = base_layer.NestedJTensor 37 | NestedJTensorOrPartitionSpec = pytypes.NestedJTensorOrPartitionSpec 38 | NestedMap = py_utils.NestedMap 39 | 40 | 41 | _ArrayOrPSpec = TypeVar('_ArrayOrPSpec', jax.Array, jax.sharding.PartitionSpec) 42 | """Either a pspec (when tracing) or a Jax tensor.""" 43 | ExtraStateType = NestedJTensorOrPartitionSpec | None 44 | 45 | 46 | # A helper class for managing various train states. This struct may contain the 47 | # actual Jax tensors, or simply PartitionSpecs for the corresponding tensor. 48 | # If the latter, this struct is used for specifying the PartitionSpecs for 49 | # input/output to/from a pjit-ed function. 
class TrainState(flax_struct.PyTreeNode, Generic[_ArrayOrPSpec]):
  """Simple train state.

  Per the module comment above, each leaf is either a concrete Jax tensor or
  a `jax.sharding.PartitionSpec` for the corresponding tensor (the latter when
  this struct specifies in/out partitioning for a pjit-ed function).

  Attributes:
    step: The global training step counter (or its partition spec).
    mdl_vars: Pytree of model variables.
    opt_states: List of optimizer-state pytrees.
    extra_state: Optional auxiliary state; `()` when unused.
  """

  step: _ArrayOrPSpec
  mdl_vars: jt.PyTree[_ArrayOrPSpec]
  opt_states: list[jt.PyTree[_ArrayOrPSpec]]
  extra_state: ExtraStateType = ()

  def new_state(
      self,
      mdl_vars: NestedJTensor,
      opt_states: list[optax.OptState],
      extra_state: ExtraStateType = (),
  ) -> TrainState:
    """Returns a new TrainState with updated mdl_vars and opt_states.

    The step counter is incremented by one, and the supplied trees are copied
    (via `trees.copy`) so the returned state does not alias the inputs.

    Args:
      mdl_vars: Updated model variables.
      opt_states: Updated optimizer states.
      extra_state: Updated auxiliary state.

    Returns:
      A fresh TrainState at `step + 1`.
    """
    return TrainState(
        step=self.step + 1,
        mdl_vars=trees.copy(mdl_vars),
        opt_states=trees.copy(opt_states),
        extra_state=trees.copy(extra_state),
    )

  def to_eval_state(self) -> TrainState:
    """Returns a new TrainState with opt_states removed, for eval purpose."""
    # mdl_vars are shared (not copied); opt_states/extra_state are dropped.
    return TrainState(
        step=self.step, mdl_vars=self.mdl_vars, opt_states=[], extra_state=()
    )


@dataclasses.dataclass
class TensorProvenance:
  """Records which checkpoint (if any) a tensor value was restored from."""

  # 'random_init' marks a tensor that was freshly initialized rather than
  # restored from a checkpoint.
  checkpoint_path: str = 'random_init'
  # Checkpoint step; None is rendered as 'latest' in __repr__.
  checkpoint_step: int | None = None

  def __repr__(self) -> str:
    if self.checkpoint_path == 'random_init':
      return f'"({self.checkpoint_path})"'

    checkpoint_step_repr = (
        self.checkpoint_step if self.checkpoint_step is not None else 'latest'
    )

    return f'"({self.checkpoint_path}:{checkpoint_step_repr})"'


@dataclasses.dataclass
class TrainStateProvenance:
  """Provenance for the TrainState pytree struct (not jax-transformable)."""

  # Each field mirrors the corresponding TrainState field, with
  # TensorProvenance leaves in place of tensors.
  step: TensorProvenance
  mdl_vars: Nested[TensorProvenance]
  opt_states: Nested[TensorProvenance]
  extra_state: Nested[TensorProvenance]

  def replace(self, **changes: Any) -> TrainStateProvenance:
    """Returns a copy of this provenance with the given fields replaced."""
    return dataclasses.replace(self, **changes)


def build_train_state_provenance(
    train_state: TrainState,
    checkpoint_path: str | None = None,
    step: int | None = None,
) -> TrainStateProvenance:
  """Builds a provenance struct mirroring the shape of `train_state`.

  Every leaf of the result is the same TensorProvenance: the given
  checkpoint path/step if `checkpoint_path` is set, else 'random_init'.

  Args:
    train_state: A concrete TrainState holding tensors (not PartitionSpecs).
    checkpoint_path: Path of the checkpoint the state was restored from, if
      any.
    step: Step of the restored checkpoint; None means the latest.

  Returns:
    A TrainStateProvenance with one TensorProvenance per tensor leaf.
  """
  assert not isinstance(
      train_state.step, jax.sharding.PartitionSpec
  ), 'Tensor provenance is only for tensors'

  provenance = TensorProvenance()
  if checkpoint_path:
    provenance = TensorProvenance(
        checkpoint_path=checkpoint_path, checkpoint_step=step
    )
  # The same provenance object is shared across all leaves by design.
  return TrainStateProvenance(
      step=provenance,
      mdl_vars=jax.tree.map(lambda x: provenance, train_state.mdl_vars),
      opt_states=jax.tree.map(lambda x: provenance, train_state.opt_states),
      extra_state=jax.tree.map(lambda x: provenance, train_state.extra_state),
  )
class InputUtilsTest(absltest.TestCase):
  """Tests splitting/merging of XLA-unsupported (string-valued) batch keys."""

  def test_split_out_xla_unsupported_batch_noop(self):
    """All-numeric batch: split must be a no-op returning the same objects."""
    batch = {'a': np.array([1, 2, 3, 4]), 'b': np.array([5, 6, 7, 8])}
    partitioning_spec = {'a': 'fake_spec', 'b': 'fake_spec'}
    out_batch, out_unsupported, partitioning_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=partitioning_spec
        )
    )
    # In no-op cases, the exact same object should be returned.
    self.assertIs(batch, out_batch)
    self.assertEmpty(out_unsupported)
    # Merging an empty unsupported dict back must leave the batch unchanged.
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertEqual(out_batch, batch)
    # Partitioning spec should not be modified.
    self.assertIsNotNone(partitioning_spec)
    self.assertCountEqual(partitioning_spec.keys(), {'a', 'b'})

  def test_split_out_xla_unsupported_batch_singly_nested(self):
    """Top-level string key 'c' must be split out and then merged back."""
    batch = {
        'a': np.array([1, 2, 3, 4]),
        'b': np.array([5, 6, 7, 8]),
        'c': np.array(['a', 'b', 'c', 'd']),
    }
    partitioning_spec = {'a': 'fake_spec', 'b': 'fake_spec', 'c': 'fake_spec'}
    out_batch, out_unsupported, new_partitioning_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=partitioning_spec
        )
    )
    self.assertCountEqual(out_batch.keys(), {'a', 'b'})
    self.assertCountEqual(out_unsupported.keys(), {'c'})
    # Verify that the unsupported parts were flattened.
    self.assertEqual(list(out_unsupported['c']), ['a', 'b', 'c', 'd'])
    # Merging back restores the full key set.
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertCountEqual(out_batch.keys(), {'a', 'b', 'c'})
    # The original partitioning_spec should not be modified.
    self.assertCountEqual(partitioning_spec.keys(), {'a', 'b', 'c'})
    # The unsupported key should have been deleted from the new partitioning
    # spec.
    self.assertIsNotNone(new_partitioning_spec)
    self.assertCountEqual(new_partitioning_spec.keys(), {'a', 'b'})

  def test_split_out_xla_unsupported_batch_multi_nested(self):
    """Nested string key 'c/d' must be split out; numeric 'c/e' must stay."""
    batch = {
        'a': np.array([1, 2, 3, 4]),
        'b': np.array([5, 6, 7, 8]),
        'c': {
            'd': np.array(['a', 'b', 'c', 'd']),
            'e': np.array([1, 2, 3, 4]),
        },
    }
    partitioning_spec = {
        'a': 'fake_spec',
        'b': 'fake_spec',
        'c': {'d': 'fake_spec', 'e': 'fake_spec'},
    }
    out_batch, out_unsupported, new_partitioning_spec = (
        xla_passthrough.split_out_xla_unsupported_batch(
            batch, partitioning_spec=partitioning_spec
        )
    )
    self.assertCountEqual(out_batch.keys(), {'a', 'b'})
    self.assertCountEqual(out_unsupported.keys(), {'c'})
    self.assertCountEqual(out_unsupported['c'].keys(), {'d'})
    # Verify that the unsupported parts were flattened.
    self.assertEqual(list(out_unsupported['c']['d']), ['a', 'b', 'c', 'd'])
    xla_passthrough.merge_back_xla_unsupported_batch(out_batch, out_unsupported)
    self.assertCountEqual(out_batch.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(out_batch['c'].keys(), {'d'})
    # The original partitioning_spec should not be modified.
    self.assertCountEqual(partitioning_spec.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(partitioning_spec['c'].keys(), {'d', 'e'})
    # The unsupported key should have been deleted from the new partitioning
    # spec.
    self.assertIsNotNone(new_partitioning_spec)
    self.assertCountEqual(new_partitioning_spec.keys(), {'a', 'b', 'c'})
    self.assertCountEqual(new_partitioning_spec['c'].keys(), {'e'})
36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "frocGIyv5Kd-" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import t5.data.tasks\n", 47 | "p = seqio_input.SeqIOInput.HParams(\n", 48 | " mixture_name='super_glue_wsc_v102_simple_train',\n", 49 | " split_name='train',\n", 50 | " task_feature_lengths={'targets': 1280},\n", 51 | " feature_converter=seqio_input.LanguageModelFeatures(pack=True),\n", 52 | " is_training=True,\n", 53 | " use_cached=False,\n", 54 | " input_random_seed=123,\n", 55 | " batch_size=4)\n", 56 | "inp = base_hyperparams.instantiate(p)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "id": "eKeVT2kelzoa" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "# Get a batch, inspect the spec of the data\n", 68 | "batch = inp.get_next()\n", 69 | "for k, v in batch.FlattenItems():\n", 70 | " print(k, v.shape, v.dtype)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "id": "dGzr0xtq7AYF" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "# The data is packed\n", 82 | "for _ in range(4):\n", 83 | " batch = inp.get_next()\n", 84 | " print('segments: ', np.max(batch.segment_ids, axis=1))\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "id": "Jq8trTrlDeXB" 91 | }, 92 | "source": [ 93 | "We set `input_random_seed=123` on the input hparams. What happens with `inp.reset()`? Does it reproduce the same data?\n", 94 | "\n", 95 | "What about if we re-instantiate the input object?" 
96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "id": "v9dhcdqj74lY" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "# Tweak some fields\n", 107 | "p2 = p.clone().set(infeed_host_index=0, num_infeed_hosts=2, shuffle=False)\n", 108 | "# disable packing\n", 109 | "p2.feature_converter = seqio_input.LanguageModelFeatures(pack=False)\n", 110 | "inp2 = base_hyperparams.instantiate(p2)\n", 111 | "\n", 112 | "p2_complement = p2.clone().set(infeed_host_index=1)\n", 113 | "inp2_complement = base_hyperparams.instantiate(p2_complement)\n", 114 | "\n", 115 | "batches = [inp2.get_next(), inp2_complement.get_next()]" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "id": "CjPqIOz68lJm" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "inp.ids_to_strings(batches[0].ids, [1280] * 4)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "G-5jh4P4CcPp" 133 | }, 134 | "source": [ 135 | "Now inspect the data from `inp2_complement`. Verify that it does not overlap with the data from `inp2`. Does this hold if we run more batches?" 
136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "id": "8fbBkr64-2tA" 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# The data is also no longer packed\n", 147 | "np.max(batches[0].segment_ids, axis=1)" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python 3 (ipykernel)", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.8.10" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 1 172 | } 173 | -------------------------------------------------------------------------------- /paxml/docs/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | doc/pax/models 4 | 5 | [TOC] 6 | 7 | ## Introduction 8 | 9 | A **model** is a collection of layers defining the network. 10 | 11 | In Pax, Models are used in [Tasks][tasks], which are part of 12 | [Experiments][experiments] that can be *trained*, *evaluated*, and *decoded* (as 13 | well as a mix of these). 14 | 15 | > Tip: For a rudamentary introduction to the basic Pax components, check out 16 | > [Pax Elements][pax-elements]. If you want to dive in for a hands-on 17 | > experience, try the [Pax Model and Task Jupyter Notebook][model_ipynb]. 
18 | 19 | ## Model How-To's 20 | 21 | ### Define a Model 22 | 23 | In Pax, a Model inherits from `BaseModel`; and `BaseModel`, in turn, inherits 24 | from `BaseLayer` along with a few interfaces for interacting with the 25 | model: 26 | 27 | * `compute_predictions()` 28 | * `compute_loss()` 29 | * `decode()` 30 | * `process_decode_out()` 31 | 32 | A BaseModel is a nothing more than a BaseLayer with a specific API that is used 33 | to integrate with the Pax trainer: 34 | 35 | ```python 36 | class BaseModel(base_layer.BaseLayer): 37 | ... 38 | ``` 39 | 40 | Therefore, to build your own Pax Model, you will need to define these methods in 41 | your derived class. 42 | 43 | #### `compute_predictions` 44 | 45 | *Computes predictions for `input_batch`.* 46 | 47 | The output can be in the form of probabilistic distributions, such as softmax 48 | logits for discrete outputs, mixture of logistics for continuous values, or 49 | regression values. 50 | 51 | For training or evaluation, the output will be used for computing loss and 52 | gradient updates, including comparing predicted distributions between teacher 53 | and student for distillation. During inference the output can be used to compute 54 | final outputs, perhaps with sampling. 55 | 56 | Args: 57 | 58 | * input_batch: A `.NestedMap` object containing input tensors. 59 | 60 | Returns: 61 | 62 | * Predictions, either a single Tensor, a `.NestedMap`, or a namedtuple. 63 | 64 | 65 | #### `compute_loss` 66 | 67 | *Computes the loss and other metrics for the given predictions.* 68 | 69 | Args: 70 | 71 | * predictions: The output of `compute_predictions`. 72 | * input_batch: A `.NestedMap` object containing input tensors to this tower. 73 | 74 | Returns: 75 | 76 | * WeightedScalars - A dict or NestedMap containing str keys and 77 | (value, weight) pairs as values, where one or more entries are 78 | expected to correspond to the loss (or losses). 
79 | * A dict containing arbitrary tensors describing something about each 80 | training example, where the first dimension of each tensor is the batch 81 | index. 82 | 83 | #### `decode` 84 | 85 | *Decodes the input_batch.* 86 | 87 | This code should be expected to run on TPUs. 88 | 89 | Args: 90 | 91 | * input_batch: The input batch. A `NestedMap` of tensors. Or, if input batch 92 | splitting is used, a list of `NestedMap`, one for each split. 93 | 94 | Returns a 3-tuple with: 95 | 96 | * weighted scalars, a NestedMap containing str keys and (value, weight) 97 | pairs for the current batch (a tuple of two scalars). 98 | * results, a `.NestedMap` as decoder output. 99 | * metrics, a NestedMap containing str keys and clu_metrics.Metric 100 | objects. 101 | 102 | #### `process_decode_out` 103 | 104 | *Processes one batch of decoded outputs.* 105 | 106 | This code will run on the host (CPU) and not on an accelerator (GPU or 107 | TPU). This allows you to run things that can't be processed on TPUs, such as 108 | strings. 109 | 110 | Args: 111 | 112 | * input_obj: The input object where a tokenizer is accessible. 113 | * decode_out: The output from decode(). May have an extra leading axis. 114 | 115 | Returns a 3-tuple with: 116 | 117 | * weighted scalars, a NestedMap containing str keys and (value, weight) 118 | pairs for the current batch (a tuple of two scalars). 119 | * A list of tuples where each element corresponds to a row in the batch. 120 | Each tuple is a key value pair. 121 | * metrics, a NestedMap containing str keys and clu_metrics.Metric 122 | objects. These will run outside of pmap/pjit. 123 | 124 | --- 125 | 126 | ### Select an Existing Model 127 | 128 | TODO: To be written. Volunteers welcome. 
129 | 130 | For a library of pre-defined models, check out [base models][base-models], which 131 | includes: 132 | 133 | * LanguageModel 134 | * SequenceModel 135 | * ClassificationModel 136 | 137 | 138 | 139 | 140 | [base-models]: https://github.com/google/praxis/tree/main/praxis/layers/models.py 141 | [experiments]: https://github.com/google/paxml/tree/main/paxml/docs/experiments.md 142 | [model_ipynb]: https://github.com/google/paxml/tree/main/paxml/docs/hands-on-tutorials.md#pax-model-and-task 143 | [pax-elements]: https://github.com/google/paxml/tree/main/paxml/docs/learning-pax.md#pax-elements 144 | [tasks]: https://github.com/google/paxml/tree/main/paxml/docs/tasks.md 145 | -------------------------------------------------------------------------------- /paxml/docs/hands-on-tutorials.md: -------------------------------------------------------------------------------- 1 | # Hands-on Tutorials 2 | 3 | doc/pax/hands-on-tutorials 4 | 5 | [TOC] 6 | 7 | TLDR: A curated set of Jupyter Notebooks to get you comfortable using Pax. 8 | 9 | ## Overview 10 | 11 | The following Jupyter Notebooks have been create by Pax SMEs. The goal is to provide a hands-on introduction 12 | basic Pax activities. 13 | 14 | Good luck! Have fun! 15 | 16 | ## Jupyter Notebooks 17 | 18 | ### A JAX Primer 19 | 20 | In this first Jupyter Notebook, you will be introduced to JAX basics. When finished, you 21 | will be ready to dive in and see how JAX is used in the Pax infrastructure for 22 | training and multipod models. 23 | 24 | **Jupyter Notebook (coming soon):** [A JAX Primer][ipynb_jax_primer] 25 | 26 | ### JAX for Edit-Distance 27 | 28 | This lab showcases a few new things you can do in Jupyter Notebook with JAX. Specifically, 29 | you will develop a native, JAX-based edit-distance algorithm that works on 30 | padded batches. 31 | 32 | You are encouraged to build your own JAX Jupyter Notebooks as a way to learn by doing. 
33 | 34 | **Jupyter Notebook (coming soon):** [JAX for Edit-Distance][ipynb_jax_ed] 35 | 36 | ### Pax Layer Basics 37 | 38 | Get your first hands-on look at Pax. You will learn about its fundamental 39 | component, the Pax Layers: the essential building blocks of models. In the 40 | process, you will learn about the basics for authoring a new Pax layer. 41 | 42 | Additionally, you will learn about Flax, a high-performance neural network 43 | library for JAX that is designed for flexibility. 44 | 45 | **Jupyter Notebook:** [Pax Layer Basics][ipynb_pax_layer] 46 | 47 | ### Inputs in Pax 48 | 49 | This short Jupyter Notebook demonstrates how inputs work in Pax. 50 | 51 | **Jupyter Notebook:** 52 | * [Inputs in Pax (training)][ipynb_pax_inputs_train] 53 | * [Inputs in Pax (eval)][ipynb_pax_inputs_eval] 54 | 55 | ### Pax Model and Task 56 | 57 | Model and task examples in Pax 58 | 59 | **Jupyter Notebook (coming soon):** [Pax Model and Task][ipynb_model_and_task] 60 | 61 | ### Pax End-to-end Tutorial 62 | 63 | Here you will put together what you have learned so far. You will be building a 64 | Translator using a Transformer Encoder/Decoder Architecture via the Pax 65 | framework. 66 | 67 | Without going too deep into the Layer design, you will see how to build, test, 68 | and combine them to train a model. 69 | 70 | **Jupyter Notebook:** [Pax End-to-end Tutorial][ipynb_pax_e2e] 71 | 72 | ### Pax Checkpointing 73 | 74 | This lab shows how to use checkpoints for warm-starting. It covers three topics 75 | around checkpoint handling: 76 | 77 | * **Run a fine-tuning** starting from a pretrained checkpoint in Pax. 78 | * **Inspect variables** in a checkpoint file. 79 | * **Test a model** interactively using a checkpoint file. 
80 | 81 | **Jupyter Notebook (coming soon):** [Pax Checkpointing][ipynb_checkpoint] 82 | 83 | ### Pax RNN Decode 84 | 85 | This Jupyter Notebook demonstates how to set up *extend step*, *init*, and *update decode 86 | state cache* for the model's autoregressive decoding. 87 | 88 | In *autoregressive decoding*, each output of the network is generated based on 89 | previously generated output. Models such as RNN and transformer uses 90 | autoregressive decoding to generate new sequence. 91 | 92 | **Jupyter Notebook (coming soon):** [Pax RNN Decode][ipynb_rnn_decode] 93 | 94 | ### Sharding in Pax 95 | 96 | This Jupyter Notebook demonstrates how to use *sharding* in Pax. 97 | 98 | **Jupyter Notebook (coming soon):** [Sharding Jupyter Notebook][ipynb_shard_hard] 99 | 100 | 101 | ## Where to go from here 102 | 103 | Congratulations! At this point, you should have a pretty good idea on how to 104 | work with the various aspects of Pax. The rest of this site should help guide 105 | you as you move on to deeper topics. 
106 | 107 | 108 | 109 | 110 | [ipynb_shard_hard]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/sharding.ipynb 111 | [ipynb_checkpoint]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax201_checkpointing.ipynb 112 | [ipynb_jax_ed]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_jax_for_edit_distance.ipynb 113 | [ipynb_jax_primer]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_jax_primer.ipynb 114 | [ipynb_model_and_task]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_model_and_task.ipynb 115 | [ipynb_pax_e2e]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax101_e2e_tutorial.ipynb 116 | [ipynb_pax_inputs_train]: https://github.com/google/paxml/blob/main/paxml/docs/tutorials/inputs_in_Pax-train.ipynb 117 | [ipynb_pax_inputs_eval]: https://github.com/google/paxml/blob/main/paxml/docs/tutorials/inputs_in_Pax-eval.ipynb 118 | [ipynb_pax_layer]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax_layer_basics.ipynb 119 | [ipynb_rnn_decode]: https://github.com/google/paxml/tree/main/paxml/docs/tutorials/pax_rnn_decode.ipynb 120 | -------------------------------------------------------------------------------- /paxml/host_callback.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""Utilities for host callbacks."""

import collections
import threading


class Repository:
  """Thread-safe container of ID-keyed strings to pass strings to host_callback.

  The intended usage is as follows:
  1. Set up a `Repository` that is accessible from both pre-processing
     and the model. For convenience, the function `repository(namespace)`
     in this module returns a per-namespace singleton `Repository`.
     Other approaches include a module-level `Repository` variable or a
     `Repository` injected via hparams.
  2. In pre-processing, use `repository(namespace).add()` to add strings
     and pass the resulting string IDs to the model.
  3. In the model/accelerator, use `repository(namespace).get()` to fetch
     strings by ID.
  4. Set the `device_index` argument in host_callback.call to match the device
     that runs pre-processing.
  5. In post-processing, use `repository(namespace).pop()` to remove strings
     by ID. There is also a last-resort eviction policy, see `MAX_SIZE`.

  To avoid OOM when the caller does not promptly pop() the strings they add(),
  there is a limit on size. If this grows beyond that limit, then strings are
  evicted in least-recently-added order.

  TODO(terrykoo): Define string ID using fingerprints to allow caching and
  reuse. If we do this, however, we will need to add refcounts so pop() only
  removes an ID when all usages of it have subsided.
  """

  # Maximum number of strings held in the repository of each namespace.
  MAX_SIZE = 10000

  def __init__(self, max_size: int = MAX_SIZE):
    """Creates an empty repository.

    If you use a non-singleton `Repository`, the generated string IDs might not
    be sufficiently unique.

    Args:
      max_size: Maximum number of strings to hold.
    """
    self._max_size = max_size
    self._lock = threading.Lock()
    # Maps ID -> string. IDs are assigned in increasing order, so dict
    # insertion order doubles as least-recently-added order for eviction.
    self._string_by_id = {}
    self._next_id_to_assign = 0
    self._next_id_to_evict = 0

  def add(self, value: str) -> int:
    """Adds new string to the mapping and returns its global ID.

    If necessary, also evicts old strings to keep this under the maximum size.

    Args:
      value: String to add.

    Returns:
      ID of the string. IDs are unique per LM server provided the caller uses
      the singleton.
    """
    with self._lock:
      string_id = self._next_id_to_assign
      self._next_id_to_assign += 1
      self._string_by_id[string_id] = value

      while len(self._string_by_id) > self._max_size:
        # Evict with a default value: the candidate ID may already have been
        # removed via pop(), and eviction must not raise KeyError in that
        # case. (A bare .pop() here would crash on the first eviction after
        # any out-of-order pop().)
        self._string_by_id.pop(self._next_id_to_evict, None)
        self._next_id_to_evict += 1

      return string_id

  def pop(self, value_id: int) -> bool:
    """Attempts to remove the string mapped to `value_id`.

    The string might not be removed if the `value_id` is unknown, already
    popped, or already evicted.

    Args:
      value_id: ID of the string to remove, as returned by add().

    Returns:
      True if the string was removed.
    """
    with self._lock:
      return self._string_by_id.pop(value_id, None) is not None

  def get(self, value_id: int) -> str:
    """Returns the string mapped to `value_id`.

    Args:
      value_id: ID of the string to fetch, as returned by add().

    Returns:
      String associated with the `value_id`.

    Raises:
      KeyError: If the `value_id` is not mapped.
    """
    with self._lock:
      return self._string_by_id[value_id]

  @property
  def size(self) -> int:
    """Returns the number of strings currently held."""
    with self._lock:
      return len(self._string_by_id)


# This is defined and instantiated after the class, because (unlike languages
# like C++, Java, or TypeScript) Python classes don't exist until after their
# definition ends.
_global_lock = threading.Lock()
_global_repository_by_namespace = collections.defaultdict(Repository)


def repository(namespace: str) -> Repository:
  """Returns the per-namespace singleton `Repository` for `namespace`."""
  with _global_lock:
    return _global_repository_by_namespace[namespace]