├── .bazelversion ├── .vscode ├── extensions.json └── settings.json ├── README.md ├── jax_tpu_embedding └── sparsecore │ ├── examples │ ├── shakespeare │ │ └── README.md │ └── models │ │ └── shakespeare │ │ ├── BUILD │ │ ├── flax_model.py │ │ ├── model.py │ │ └── flax_nnx_model.py │ ├── version.py.in │ ├── lib │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── primitives │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── local_sparse_dense_matmul.py │ │ │ ├── optimizers_computation.py │ │ │ └── tests │ │ │ │ └── local_sparse_dense_matmul_test.py │ │ ├── constants.py │ │ ├── grpc │ │ │ ├── oss │ │ │ │ ├── BUILD.bazel │ │ │ │ └── grpc_credentials.cc │ │ │ ├── all_reduce.proto │ │ │ ├── grpc_credentials.h │ │ │ ├── minibatching_node.h │ │ │ ├── all_reduce_interface.h │ │ │ ├── all_reduce_service_impl.h │ │ │ └── BUILD │ │ ├── input_preprocessing_threads_test.cc │ │ ├── input_preprocessing_threads.h │ │ ├── test_utils.py │ │ ├── all_reduce_interface.h │ │ ├── unity_weights_stream_impl.h │ │ ├── minibatching_test_utils.h │ │ ├── input_preprocessing_threads.cc │ │ ├── input_preprocessing.h │ │ ├── abstract_input_batch.h │ │ ├── numpy_input_batch.h │ │ ├── sparse_coo_input_batch.h │ │ ├── sparse_coo_input_batch.cc │ │ ├── minibatching_splits_test.cc │ │ ├── sparse_csr_input_stream_impl.h │ │ └── minibatching_splits_impl.h │ ├── fdo │ │ ├── __init__.py │ │ ├── BUILD │ │ ├── fdo_client.py │ │ └── file_fdo_client_test.py │ ├── flax │ │ ├── __init__.py │ │ ├── nnx │ │ │ ├── __init__.py │ │ │ ├── BUILD │ │ │ └── tests │ │ │ │ └── BUILD │ │ ├── linen │ │ │ ├── __init__.py │ │ │ ├── BUILD │ │ │ ├── tests │ │ │ │ ├── BUILD │ │ │ │ └── embed_optimizer_test.py │ │ │ └── embed_optimizer.py │ │ └── BUILD │ ├── nn │ │ ├── __init__.py │ │ └── BUILD │ ├── proto │ │ ├── __init__.py │ │ ├── BUILD │ │ └── embedding_spec.proto │ ├── BUILD │ └── auto_pipelining │ │ ├── BUILD │ │ ├── utils.py │ │ └── preprocess.py │ ├── utils │ ├── __init__.py │ ├── BUILD │ └── utils.py │ ├── __init__.py │ ├── tests │ ├── version_test.py │ └── BUILD │ ├── jax_tpu_embedding.bzl │ ├── BUILD │ └── configure_file.bzl ├── .gitignore ├── third_party ├── py │ ├── requirements.in │ └── BUILD.bazel ├── jax │ ├── tpu_hardware_device.py │ └── BUILD.bazel ├── xla │ ├── BUILD.bazel │ ├── revision.bzl │ └── workspace.bzl └── bazel │ └── python │ ├── BUILD.bazel │ ├── pybind11.bzl │ ├── python_init_pip.bzl │ ├── pytype.bzl │ └── pypi.bzl ├── .github └── workflows │ ├── stale.yml │ └── build_and_test.yml ├── CHANGELOG.md ├── BUILD.bazel ├── CONTRIBUTING.md ├── tools ├── docker_build_wheel.sh ├── BUILD.bazel ├── local_build_wheel.sh └── install_bazelisk.sh ├── pyproject.toml └── WORKSPACE /.bazelversion: -------------------------------------------------------------------------------- 1 | 7.7.0 2 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.black-formatter" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jax_tpu_embedding 2 | 3 | [![Unittests](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml/badge.svg)](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml) 4 | 5 | Usage instructions coming soon! 
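In the meantime, a wheel can be built from source with the helper script included in this repository (a minimal sketch; see `tools/local_build_wheel.sh` for the environment variables, such as `JTE_WHEEL_OUTDIR`, that control the build):

```sh
# Run from the repository root; the wheel is written to ./dist by default.
bash tools/local_build_wheel.sh
```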
6 | 
7 | *This is not an officially supported Google product.*
8 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/examples/shakespeare/README.md:
--------------------------------------------------------------------------------
1 | # A simple Shakespeare model using the JAX SparseCore API
2 | 
3 | ## About
4 | 
5 | This directory contains two versions of the simple Shakespeare model that run
6 | distributed training with the embedding layer on SparseCore and a dense tower
7 | on TensorCore. One is implemented using `pmap()` and the other uses `jit` +
8 | `shard_map`.
9 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 | 
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 | 
8 | # Poetry, setuptools, PyPI distribution artifacts.
9 | /*.egg-info
10 | .eggs/
11 | dist/
12 | poetry.lock
13 | 
14 | # Tests
15 | .pytest_cache/
16 | 
17 | # Type checking
18 | .pytype/
19 | 
20 | # Other
21 | *.DS_Store
22 | 
23 | # PyCharm
24 | .idea
25 | 
26 | # Bazel
27 | /bazel-*
28 | 
29 | # Built wheels.
30 | /*.whl
31 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/version.py.in:
--------------------------------------------------------------------------------
1 | """JAX TPU Embedding versioning utilities.
2 | 
3 | For releases, the version is of the form:
4 | xx.yy.zz
5 | 
6 | For nightly builds, the date of the build is added:
7 | xx.yy.zz-devYYYYMMDD
8 | """
9 | 
10 | _base_version = "0.1.0"
11 | _version_suffix = "${VERSION_SUFFIX}"
12 | 
13 | # Git commit corresponding to the build, if available.
14 | __git_commit__ = "${GIT_COMMIT}"
15 | 
16 | # Library version.
17 | __version__ = _base_version + _version_suffix
18 | 
--------------------------------------------------------------------------------
/third_party/py/requirements.in:
--------------------------------------------------------------------------------
1 | # Library.
2 | absl-py
3 | flax @ https://github.com/google/flax/archive/e2134af.zip
4 | numpy
5 | # Pre-release of JAX required for SparseCore TPUs.
6 | jax[tpu] --pre
7 | portpicker
8 | 
9 | # Build utilities.
10 | auditwheel
11 | build
12 | packaging
13 | setuptools
14 | wheel
15 | 
16 | # Testing.
17 | clu
18 | einops
19 | google-benchmark
20 | optax
21 | orbax
22 | protobuf
23 | 
24 | # Paths for pre-releases of JAX and libtpu.
25 | --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
26 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
27 | 
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: Close stale PRs
2 | 
3 | on:
4 |   schedule:
5 |     - cron: '0 0 * * *'  # Run daily at midnight
6 | 
7 | jobs:
8 |   stale:
9 |     runs-on: ubuntu-latest
10 |     permissions:
11 |       issues: write
12 |       pull-requests: write
13 |     steps:
14 |       - uses: actions/stale@v9
15 |         with:
16 |           repo-token: ${{ secrets.GITHUB_TOKEN }}
17 |           stale-pr-message: 'This PR was closed due to 90 days of inactivity. If this change is still desired, please recreate the PR from the latest main branch.'
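          # The message above is posted when a PR is marked stale; with
          # `days-before-close: 0` below, the PR is then closed in the same
          # daily run.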
18 | days-before-stale: 90 19 | days-before-close: 0 # Close PRs as soon as they are marked stale 20 | only-pulls: true 21 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/fdo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/nnx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 
15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/proto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/linen/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Empty file needed by setuptools.find_packages to recognize this as a package. 15 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """JAX SparseCore library."""
15 | 
16 | from jax_tpu_embedding.sparsecore import version
17 | 
18 | __version__ = version.__version__
19 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Defines global constants used in the JAX SC Core library implementation."""
15 | 
16 | # The padding value that we use for the CSR-wrapped COO tensors.
17 | PADDING_VALUE = 2 ** 31 - 1
18 | 
--------------------------------------------------------------------------------
/third_party/jax/tpu_hardware_device.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
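# Tests select a TPU chip configuration with the flag defined below, e.g.
# --tpu_chip_config_name=default.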
14 | """Flags for configuring the TPU in tests.""" 15 | 16 | from absl import flags 17 | 18 | _TPU_CHIP_CONFIG_NAME = flags.DEFINE_string( 19 | "tpu_chip_config_name", 20 | "default", 21 | "Selects a configuration for the TPU chip by name.", 22 | ) 23 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.insertFinalNewline": true, 3 | "files.trimFinalNewlines": true, 4 | "files.trimTrailingWhitespace": true, 5 | "files.associations": { 6 | ".pylintrc": "ini" 7 | }, 8 | "python.testing.unittestEnabled": false, 9 | "python.testing.nosetestsEnabled": false, 10 | "python.testing.pytestEnabled": true, 11 | "python.linting.pylintUseMinimalCheckers": false, 12 | "[python]": { 13 | "editor.rulers": [80], 14 | "editor.tabSize": 2, 15 | "editor.defaultFormatter": "ms-python.black-formatter", 16 | "editor.formatOnSave": true, 17 | "editor.detectIndentation": false 18 | }, 19 | "python.formatting.provider": "none", 20 | "black-formatter.path": ["pyink"], 21 | "files.watcherExclude": { 22 | "**/.git/**": true 23 | }, 24 | "files.exclude": { 25 | "**/__pycache__": true, 26 | "**/.pytest_cache": true, 27 | "**/*.egg-info": true 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /third_party/xla/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 15 | 16 | package( 17 | default_applicable_licenses = ["//:license"], 18 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 19 | ) 20 | 21 | bzl_library( 22 | name = "workspace_bzl", 23 | srcs = ["workspace.bzl"], 24 | visibility = ["//visibility:private"], 25 | ) 26 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 23 | 24 | ## [Unreleased] 25 | 26 | ## [0.1.0] - 2025-01-01 27 | 28 | * Initial release 29 | 30 | [Unreleased]: https://github.com/jax-ml/jax_tpu_embedding/compare/v0.1.0...HEAD 31 | [0.1.0]: https://github.com/jax-ml/jax_tpu_embedding/releases/tag/v0.1.0 32 | -------------------------------------------------------------------------------- /BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("@rules_license//rules:license.bzl", "license") 15 | 16 | package( 17 | default_applicable_licenses = [":license"], 18 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 19 | ) 20 | 21 | license( 22 | name = "license", 23 | package_name = "jax_tpu_embedding", 24 | ) 25 | 26 | licenses(["notice"]) 27 | 28 | exports_files([ 29 | "LICENSE", 30 | "pyproject.toml", 31 | "README.md", 32 | ]) 33 | -------------------------------------------------------------------------------- /third_party/xla/revision.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # To update XLA to a new revision, 15 | # a) update XLA_COMMIT to the new git commit hash 16 | # b) get the sha256 hash of the commit by running: 17 | # curl -L https://api.github.com/repos/openxla/xla/tarball/{git_hash} | sha256sum 18 | # and update XLA_SHA256 with the result. 19 | 20 | # buildifier: disable=module-docstring 21 | XLA_COMMIT = "249feb06ea4fa587cd9a5eeee8c8c17cc8597582" 22 | XLA_SHA256 = "8c9af3417b877c665adec6cdb3aa59e51a387d198a8d5fd599e77f6548b28389" 23 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/oss/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
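# OSS build of the gRPC credential helpers. grpc_credentials.cc (below) returns
# insecure channel/server credentials, hence the dependency on the
# grpc++_unsecure variant of gRPC.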
14 | load("@rules_cc//cc:cc_library.bzl", "cc_library") 15 | 16 | package( 17 | default_applicable_licenses = ["//:license"], 18 | default_visibility = ["//jax_tpu_embedding/sparsecore/lib/core/grpc:__subpackages__"], 19 | licenses = ["notice"], 20 | ) 21 | 22 | cc_library( 23 | name = "grpc_credentials", 24 | srcs = ["grpc_credentials.cc"], 25 | deps = [ 26 | "@com_github_grpc_grpc//:grpc++_unsecure", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h" 15 | 16 | #include 17 | #include "tsl/platform/threadpool.h" // from @tsl 18 | 19 | namespace jax_sc_embedding { 20 | namespace { 21 | 22 | TEST(InputPreprocessingThreadsTest, CreateThreadPool) { 23 | tsl::thread::ThreadPool* pool = PreprocessingThreadPool(); 24 | EXPECT_NE(pool, nullptr); 25 | } 26 | 27 | } // namespace 28 | } // namespace jax_sc_embedding 29 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto3"; 15 | 16 | package jax_sc_embedding; 17 | 18 | message AllReduceData { 19 | int32 sync_key = 1; 20 | int32 src_rank = 2; 21 | oneof value { 22 | bool bool_val = 3; 23 | uint64 uint64_val = 4; 24 | } 25 | } 26 | 27 | message AllReduceResponse {} 28 | 29 | service AllReduceGrpcService { 30 | // Unary RPC for a host to send its data to another host as part of an 31 | // All-to-All exchange. 32 | rpc ContributeData(AllReduceData) returns (AllReduceResponse) {} 33 | } 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code Reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 
26 | ## Community Guidelines
27 | 
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The JAX SC Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_THREADS_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_THREADS_H_
16 | 
17 | #include "tsl/platform/threadpool.h"  // from @tsl
18 | 
19 | namespace jax_sc_embedding {
20 | 
21 | // Global thread pool for all computations done by input preprocessing.
22 | tsl::thread::ThreadPool* PreprocessingThreadPool();
23 | 
24 | }  // namespace jax_sc_embedding
25 | 
26 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_THREADS_H_
27 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/tests/version_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test version string generation.""" 15 | 16 | from absl.testing import absltest 17 | from jax_tpu_embedding import sparsecore 18 | from jax_tpu_embedding.sparsecore import version 19 | 20 | 21 | class VersionTest(absltest.TestCase): 22 | 23 | def test_version_string(self): 24 | self.assertEqual(sparsecore.__version__, version.__version__) 25 | self.assertTrue(version.__version__.startswith(version._base_version)) 26 | self.assertTrue(version.__version__.endswith(version._version_suffix)) 27 | 28 | 29 | if __name__ == "__main__": 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /third_party/jax/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Extra utilities for working with JAX.""" 15 | 16 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 17 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 18 | 19 | package( 20 | default_applicable_licenses = ["//:license"], 21 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 22 | ) 23 | 24 | # Stub library for depending on TPU hardware in tests. 25 | pytype_strict_library( 26 | name = "tpu_support", 27 | srcs = ["tpu_hardware_device.py"], 28 | visibility = ["//visibility:public"], 29 | deps = [pypi_requirement("absl/flags")], 30 | ) 31 | -------------------------------------------------------------------------------- /third_party/py/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("@python//:defs.bzl", "compile_pip_requirements") 15 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") 16 | 17 | licenses(["notice"]) 18 | 19 | package( 20 | default_applicable_licenses = ["//:license"], 21 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 22 | ) 23 | 24 | compile_pip_requirements( 25 | name = "requirements", 26 | extra_args = [ 27 | "--allow-unsafe", 28 | "--build-isolation", 29 | "--rebuild", 30 | ], 31 | generate_hashes = True, 32 | requirements_in = "requirements.in", 33 | requirements_txt = REQUIREMENTS, 34 | ) 35 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 16 | 17 | package( 18 | default_applicable_licenses = ["//:license"], 19 | default_visibility = EXTERNAL_USERS, 20 | ) 21 | 22 | # Library target. 23 | pytype_strict_library( 24 | name = "flax", 25 | srcs = ["__init__.py"], 26 | visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"], 27 | deps = [ 28 | "//jax_tpu_embedding/sparsecore/lib/flax/linen", # buildcleaner: keep 29 | "//jax_tpu_embedding/sparsecore/lib/flax/nnx", # buildcleaner: keep 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/utils/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = EXTERNAL_USERS, 21 | ) 22 | 23 | pytype_strict_library( 24 | name = "utils", 25 | srcs = [ 26 | "__init__.py", 27 | "utils.py", 28 | ], 29 | deps = [ 30 | pypi_requirement("absl/flags"), 31 | pypi_requirement("jax"), 32 | pypi_requirement("jax/experimental:layout"), 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /third_party/bazel/python/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Bazel extensions for python rules. """ 15 | 16 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 21 | ) 22 | 23 | bzl_library( 24 | name = "pybind11_bzl", 25 | srcs = ["pybind11.bzl"], 26 | visibility = ["//visibility:private"], 27 | ) 28 | 29 | bzl_library( 30 | name = "pytype_bzl", 31 | srcs = ["pytype.bzl"], 32 | visibility = ["//visibility:private"], 33 | ) 34 | 35 | bzl_library( 36 | name = "pypi_bzl", 37 | srcs = ["pypi.bzl"], 38 | visibility = ["//visibility:private"], 39 | ) 40 | -------------------------------------------------------------------------------- /third_party/bazel/python/pybind11.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Pybind11 Extensions.""" 15 | 16 | load("@bazel_skylib//lib:collections.bzl", "collections") 17 | load("@rules_cc//cc:cc_library.bzl", "cc_library") 18 | load("@xla//xla/tsl:tsl.bzl", "tsl_pybind_extension_opensource") 19 | 20 | def pybind_extension(name, deps, **kwargs): 21 | # Add pybind11 to deps. 22 | deps = collections.uniq(deps + ["@pybind11"]) 23 | tsl_pybind_extension_opensource(name = name, deps = deps, **kwargs) 24 | 25 | def pybind_library(name, deps, **kwargs): 26 | # Add pybind11 and python headers to deps. 
27 | deps = collections.uniq(deps + ["@pybind11", "@local_config_python//:python_headers"]) 28 | cc_library(name = name, deps = deps, **kwargs) 29 | -------------------------------------------------------------------------------- /third_party/xla/workspace.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # buildifier: disable=module-docstring 15 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 16 | load("//third_party/xla:revision.bzl", "XLA_COMMIT", "XLA_SHA256") 17 | 18 | XLA_ARCHIVE = "https://api.github.com/repos/openxla/xla/tarball/{commit}".format(commit = XLA_COMMIT) 19 | 20 | def repo(): 21 | http_archive( 22 | name = "xla", 23 | sha256 = XLA_SHA256, 24 | type = "tar.gz", 25 | strip_prefix = "openxla-xla-{commit}".format(commit = XLA_COMMIT[:7]), 26 | urls = [ 27 | # Try TF mirror first. 28 | "https://storage.googleapis.com/mirror.tensorflow.org/%s" % XLA_ARCHIVE[8:], 29 | XLA_ARCHIVE, 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /tools/docker_build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Script to build python wheels in a docker container. Assumes linux x86_64. 4 | # Run from the root folder. 5 | 6 | if [ -z "$JTE_DOCKER_IMAGE" ]; then 7 | JTE_DOCKER_IMAGE=us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest 8 | fi 9 | 10 | if [ -z "$JTE_DOCKER_WORKDIR" ]; then 11 | JTE_DOCKER_WORKDIR=/build/jax_tpu_embedding 12 | fi 13 | 14 | # Mark the wheel output as relative to the docker build folder. 15 | if [ -z "$JTE_WHEEL_OUTDIR" ]; then 16 | JTE_WHEEL_OUTDIR="${JTE_DOCKER_WORKDIR}/dist" 17 | fi 18 | 19 | # Try to determine the git commit ID locally if not set. 20 | # (The docker container does not contain git). 21 | if [ -z "$JTE_GIT_SHA" ]; then 22 | # Extract git hash from current git folder, if any. 23 | JTE_GIT_SHA=`git rev-parse HEAD 2> /dev/null || echo ""` 24 | fi 25 | 26 | docker run \ 27 | -v "$PWD":"${JTE_DOCKER_WORKDIR}" \ 28 | -w "${JTE_DOCKER_WORKDIR}" \ 29 | --env HERMETIC_PYTHON_VERSION="${HERMETIC_PYTHON_VERSION}" \ 30 | --env JTE_HERMETIC_PYTHON_VERSION="${JTE_HERMETIC_PYTHON_VERSION}" \ 31 | --env JTE_RELEASE="${JTE_RELEASE}" \ 32 | --env JTE_VERSION_SUFFIX="${JTE_VERSION_SUFFIX}" \ 33 | --env JTE_GIT_SHA="${JTE_GIT_SHA}" \ 34 | --env JTE_WHEEL_OUTDIR="${JTE_WHEEL_OUTDIR}" \ 35 | "${JTE_DOCKER_IMAGE}" \ 36 | bash -c tools/local_build_wheel.sh 37 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/oss/grpc_credentials.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include <memory>
16 | 
17 | #include "include/grpcpp/security/credentials.h"  // from @com_github_grpc_grpc
18 | #include "include/grpcpp/security/server_credentials.h"  // from @com_github_grpc_grpc
19 | 
20 | namespace jax_sc_embedding {
21 | namespace rpc {
22 | 
23 | // Returns insecure credentials for use when creating a gRPC server.
24 | std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials() {
25 |   return ::grpc::InsecureServerCredentials();
26 | }
27 | 
28 | // Returns insecure credentials for use when creating a gRPC channel.
29 | std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials() {
30 |   return ::grpc::InsecureChannelCredentials();
31 | }
32 | 
33 | }  // namespace rpc
34 | }  // namespace jax_sc_embedding
35 | 
--------------------------------------------------------------------------------
/third_party/bazel/python/python_init_pip.bzl:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Hermetic Python initialization."""
15 | 
16 | load("@python//:defs.bzl", "interpreter")
17 | load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS")
18 | load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse")
19 | 
20 | def python_init_pip():
21 |     setuptools_annotations = {
22 |         # We require the "Lorem ipsum.txt" file from the following directory,
23 |         # but cannot depend directly on a filename containing spaces.
24 |         "setuptools": package_annotation(
25 |             data = [":site-packages/setuptools/_vendor/jaraco/text"],
26 |         ),
27 |     }
28 | 
29 |     pip_parse(
30 |         name = "pypi",
31 |         annotations = setuptools_annotations,
32 |         python_interpreter_target = interpreter,
33 |         requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS,
34 |     )
35 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/proto/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") 15 | load("@rules_python//python:proto.bzl", "py_proto_library") 16 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 17 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 18 | 19 | package( 20 | default_applicable_licenses = ["//:license"], 21 | default_visibility = EXTERNAL_USERS, 22 | ) 23 | 24 | proto_library( 25 | name = "embedding_spec_proto", 26 | srcs = ["embedding_spec.proto"], 27 | ) 28 | 29 | py_proto_library( 30 | name = "embedding_spec_py_pb2", 31 | deps = [":embedding_spec_proto"], 32 | ) 33 | 34 | pytype_strict_library( 35 | name = "proto", 36 | srcs = ["__init__.py"], 37 | visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"], 38 | deps = [":embedding_spec_py_pb2"], # buildcleaner: keep 39 | ) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jax_tpu_embedding" 3 | description = "JAX SparseCore API" 4 | readme = "README.md" 5 | license = "Apache-2.0" 6 | license-files = ["LICENSE"] 7 | requires-python = ">=3.10" 8 | authors = [{name = "JAX SC Authors", email="jax-sc-dev@google.com"}] 9 | classifiers = [ # List of https://pypi.org/classifiers/ 10 | "Intended Audience :: Science/Research", 11 | 'Programming Language :: Python :: 3.10', 12 | 'Programming Language :: Python :: 3.11', 13 | 'Programming Language :: Python :: 3.12', 14 | 'Programming Language :: Python :: 3.13', 15 | ] 16 | keywords = ["jax", "tpu", "embedding", "sparsecore", "machine learning", "sparse tensor", "distributed computing"] 17 | dependencies = [ 18 | 'absl-py', 19 | 'einops', 20 | 'flax', 21 | 'jax', 22 | 'numpy', 23 | 'portpicker' 24 | ] 25 | 26 | # `version` is set by setuptools to use `jax_tpu_embedding.sparsecore.__version__`. 27 | dynamic = ["version"] 28 | 29 | [project.urls] 30 | homepage = "https://github.com/jax-ml/jax-tpu-embedding" 31 | repository = "https://github.com/jax-ml/jax-tpu-embedding" 32 | issue_tracker = "https://github.com/jax-ml/jax-tpu-embedding/issues" 33 | 34 | [build-system] 35 | requires = ["setuptools"] 36 | build-backend = "setuptools.build_meta" 37 | 38 | [tool.setuptools.dynamic] 39 | version = {attr = "jax_tpu_embedding.sparsecore.__version__"} 40 | 41 | [tool.setuptools.packages.find] 42 | include = ["jax_tpu_embedding*"] 43 | 44 | [tool.setuptools.package-data] 45 | "*" = ["*.dylib", "*.so", "*.pyd"] 46 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/grpc_credentials.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_GRPC_CREDENTIALS_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_GRPC_CREDENTIALS_H_
16 | 
17 | #include <memory>
18 | 
19 | #include "include/grpcpp/security/credentials.h"  // from @com_github_grpc_grpc
20 | #include "include/grpcpp/security/server_credentials.h"  // from @com_github_grpc_grpc
21 | 
22 | namespace jax_sc_embedding {
23 | namespace rpc {
24 | 
25 | // Returns default credentials for use when creating a gRPC server.
26 | std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials();
27 | 
28 | // Returns default credentials for use when creating a gRPC channel.
29 | std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials();
30 | 
31 | }  // namespace rpc
32 | }  // namespace jax_sc_embedding
33 | 
34 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_GRPC_CREDENTIALS_H_
35 | 
--------------------------------------------------------------------------------
/tools/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | load("@pypi//:requirements.bzl", "requirement")
15 | 
16 | licenses(["notice"])
17 | 
18 | package(
19 |     default_applicable_licenses = ["//:license"],
20 |     default_visibility = ["//jax_tpu_embedding/sparsecore:internal"],
21 | )
22 | 
23 | ###############################################################################
24 | # PIP Package
25 | ###############################################################################
26 | 
27 | py_binary(
28 |     name = "build_wheel",
29 |     srcs = ["build_wheel.py"],
30 |     data = [
31 |         "//:LICENSE",
32 |         "//:README.md",
33 |         "//:pyproject.toml",
34 |     ],
35 |     deps = [
36 |         "//jax_tpu_embedding/sparsecore",
37 |         requirement("absl_py"),
38 |         requirement("auditwheel"),
39 |         requirement("build"),
40 |         requirement("packaging"),
41 |         requirement("setuptools"),
42 |         requirement("wheel"),
43 |     ],
44 | )
45 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Test utils for sparsecore input preprocessing.""" 15 | 16 | import numpy as np 17 | 18 | 19 | def assert_equal_coo_buffer( 20 | local_device_count: int, 21 | num_sc_per_device: int, 22 | row_pointers: np.ndarray, 23 | actual: np.ndarray, 24 | expected: np.ndarray, 25 | ): 26 | """Compare COO buffers ignoring end of SC/device padding.""" 27 | local_sc_count = local_device_count * num_sc_per_device 28 | 29 | # Ignore leading dim 30 | row_pointers = row_pointers.reshape(-1) 31 | actual = actual.reshape(-1) 32 | expected = expected.reshape(-1) 33 | 34 | for row_pointer_slice, actual_sc_slice, expected_sc_slice in zip( 35 | np.split(row_pointers, local_sc_count), 36 | np.split(actual, local_sc_count), 37 | np.split(expected, local_sc_count), 38 | ): 39 | np.testing.assert_almost_equal( 40 | actual_sc_slice[: row_pointer_slice[-1]], 41 | expected_sc_slice[: row_pointer_slice[-1]], 42 | decimal=6, 43 | ) 44 | -------------------------------------------------------------------------------- /tools/local_build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Script to build the python wheel. Run from the root folder. 4 | 5 | # Determine the hermetic python version. 6 | if [ -z "$JTE_HERMETIC_PYTHON_VERSION" ]; then 7 | if [ -z "$HERMETIC_PYTHON_VERSION" ]; then 8 | # Use the default hermetic python version. 9 | JTE_HERMETIC_PYTHON_VERSION=3.12 10 | else 11 | JTE_HERMETIC_PYTHON_VERSION="${HERMETIC_PYTHON_VERSION}" 12 | fi 13 | fi 14 | 15 | # Try to determine the git commit ID if not set. 16 | if [ -z "$JTE_GIT_SHA" ]; then 17 | # Extract git hash from current git folder, if any. 18 | JTE_GIT_SHA=`git rev-parse HEAD 2> /dev/null || echo ""` 19 | fi 20 | 21 | # Determine the appropriate wheel suffix. If it's not release, 22 | # and if the version suffix is not explicitly set, build a dev version. 23 | if [ -z "$JTE_RELEASE" ] && [ -z "$JTE_VERSION_SUFFIX" ]; then 24 | # Build suffix as dev${DATE} 25 | JTE_VERSION_SUFFIX="dev$(date '+%Y%m%d')" 26 | fi 27 | 28 | # Output directory for the wheel. 29 | if [ -z "$JTE_WHEEL_OUTDIR" ]; then 30 | JTE_WHEEL_OUTDIR="$PWD/dist" 31 | fi 32 | 33 | echo "JTE_HERMETIC_PYTHON_VERSION: ${JTE_HERMETIC_PYTHON_VERSION}" 34 | echo "JTE_RELEASE: ${JTE_RELEASE}" 35 | echo "JTE_VERSION_SUFFIX: ${JTE_VERSION_SUFFIX}" 36 | echo "JTE_GIT_SHA: ${JTE_GIT_SHA}" 37 | echo "JTE_WHEEL_OUTDIR: ${JTE_WHEEL_OUTDIR}" 38 | 39 | bazel run //tools:build_wheel --verbose_failures \ 40 | --repo_env=HERMETIC_PYTHON_VERSION="${JTE_HERMETIC_PYTHON_VERSION}" \ 41 | --//jax_tpu_embedding/sparsecore:version_suffix="${JTE_VERSION_SUFFIX}" \ 42 | --//jax_tpu_embedding/sparsecore:git_commit="${JTE_GIT_SHA}" \ 43 | -- --output_dir="${JTE_WHEEL_OUTDIR}" 44 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 15 | 16 | package( 17 | default_applicable_licenses = ["//:license"], 18 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 19 | ) 20 | 21 | pytype_strict_library( 22 | name = "lib", 23 | srcs = ["__init__.py"], 24 | visibility = ["//jax_tpu_embedding/sparsecore:__pkg__"], 25 | deps = [ 26 | "//jax_tpu_embedding/sparsecore/lib/auto_pipelining", # buildcleaner: keep 27 | "//jax_tpu_embedding/sparsecore/lib/core", # buildcleaner: keep 28 | "//jax_tpu_embedding/sparsecore/lib/fdo", # buildcleaner: keep 29 | "//jax_tpu_embedding/sparsecore/lib/flax", # buildcleaner: keep 30 | "//jax_tpu_embedding/sparsecore/lib/flax/linen", # buildcleaner: keep 31 | "//jax_tpu_embedding/sparsecore/lib/flax/nnx", # buildcleaner: keep 32 | "//jax_tpu_embedding/sparsecore/lib/nn", # buildcleaner: keep 33 | "//jax_tpu_embedding/sparsecore/lib/proto", # buildcleaner: keep 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /third_party/bazel/python/pytype.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Default (OSS) build versions of Python pytype rules.""" 15 | 16 | load("@bazel_skylib//lib:collections.bzl", "collections") 17 | load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") 18 | 19 | # Placeholder to use until bazel supports pytype_library. 20 | def pytype_library(name, deps = [], pytype_deps = [], pytype_srcs = [], **kwargs): 21 | _ = (pytype_deps, pytype_srcs) # @unused 22 | py_library(name = name, deps = collections.uniq(deps), **kwargs) 23 | 24 | # Placeholder to use until bazel supports pytype_strict_binary. 25 | def pytype_strict_binary(name, deps = [], **kwargs): 26 | py_binary(name = name, deps = collections.uniq(deps), **kwargs) 27 | 28 | # Placeholder to use until bazel supports pytype_strict_library. 29 | def pytype_strict_library(name, deps = [], **kwargs): 30 | py_library(name = name, deps = collections.uniq(deps), **kwargs) 31 | 32 | # Placeholder to use until bazel supports pytype_strict_contrib_test. 
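# These placeholders forward to the underlying rules_python rules, so call
# sites use ordinary py_test attributes, e.g. (hypothetical target):
#
#   pytype_strict_contrib_test(
#       name = "foo_test",
#       srcs = ["foo_test.py"],
#       deps = [":foo"],
#   )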
33 | def pytype_strict_contrib_test(name, deps = [], **kwargs): 34 | py_test(name = name, deps = collections.uniq(deps), **kwargs) 35 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/nnx/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = EXTERNAL_USERS, 21 | ) 22 | 23 | pytype_strict_library( 24 | name = "embed", 25 | srcs = [ 26 | "embed.py", 27 | ], 28 | deps = [ 29 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 30 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 31 | "//jax_tpu_embedding/sparsecore/utils", 32 | pypi_requirement("flax/nnx"), 33 | pypi_requirement("jax"), 34 | pypi_requirement("optax"), 35 | ], 36 | ) 37 | 38 | # Library target. 39 | pytype_strict_library( 40 | name = "nnx", 41 | srcs = ["__init__.py"], 42 | visibility = [ 43 | "//jax_tpu_embedding/sparsecore/lib:__pkg__", 44 | "//jax_tpu_embedding/sparsecore/lib/flax:__pkg__", 45 | ], 46 | deps = [ 47 | ":embed", # buildcleaner: keep 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/auto_pipelining/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 15 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 16 | 17 | package( 18 | default_applicable_licenses = ["//:license"], 19 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 20 | ) 21 | 22 | pytype_strict_library( 23 | name = "utils", 24 | srcs = ["utils.py"], 25 | deps = [ 26 | pypi_requirement("jax"), 27 | pypi_requirement("jax/extend"), 28 | ], 29 | ) 30 | 31 | pytype_strict_library( 32 | name = "decompose", 33 | srcs = ["decompose.py"], 34 | deps = [ 35 | ":preprocess", 36 | ":utils", 37 | pypi_requirement("jax"), 38 | pypi_requirement("jax/extend"), 39 | ], 40 | ) 41 | 42 | pytype_strict_library( 43 | name = "preprocess", 44 | srcs = ["preprocess.py"], 45 | deps = [ 46 | ":utils", 47 | pypi_requirement("jax/extend"), 48 | ], 49 | ) 50 | 51 | pytype_strict_library( 52 | name = "auto_pipelining", 53 | srcs = ["auto_pipelining.py"], 54 | deps = [ 55 | ":decompose", 56 | pypi_requirement("jax"), 57 | pypi_requirement("jax/extend"), 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /tools/install_bazelisk.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #!/usr/bin/env bash 3 | # ============================================================================== 4 | # Copyright 2024 The JAX SC Authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | # 19 | # Script to install bazelisk locally to $HOME/bin. 20 | # 21 | # Usage: 22 | # . build/install_bazelisk.sh 23 | 24 | if [[ -z "${BAZELISK_VERSION}" ]]; then 25 | BAZELISK_VERSION=v1.15.0 26 | fi 27 | 28 | # Downloads bazelisk to ~/bin as `bazel`. 29 | function install_bazelisk { 30 | case "$(uname -s)" in 31 | Darwin) local name=bazelisk-darwin-amd64 ;; 32 | Linux) 33 | case "$(uname -m)" in 34 | x86_64) local name=bazelisk-linux-amd64 ;; 35 | aarch64) local name=bazelisk-linux-arm64 ;; 36 | *) die "Unknown machine type: $(uname -m)" ;; 37 | esac ;; 38 | *) die "Unknown OS: $(uname -s)" ;; 39 | esac 40 | 41 | mkdir -p "$HOME/bin" 42 | wget --no-verbose -O "$HOME/bin/bazel" \ 43 | "https://github.com/bazelbuild/bazelisk/releases/download/$BAZELISK_VERSION/$name" \ 44 | 2> /dev/null 45 | 46 | chmod u+x "$HOME/bin/bazel" 47 | if [[ ! ":$PATH:" =~ :"$HOME"/bin/?: ]]; then 48 | export PATH="$HOME/bin:$PATH" 49 | fi 50 | } 51 | 52 | install_bazelisk 53 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/jax_tpu_embedding.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Provides Python test rules for JAX TPU Embedding TPU tests."""
15 | 
16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test")
17 | 
18 | # Visibility rules.
19 | EXTERNAL_USERS = ["//visibility:public"]
20 | 
21 | # Use jax_tpu_embedding/sparsecore/lib/nn/embedding.py.
22 | CORE_USERS = [
23 |     "//jax_tpu_embedding/sparsecore:__subpackages__",
24 | ]
25 | 
26 | def tpu_py_strict_test(
27 |         name,
28 |         tags = None,
29 |         deps = None,
30 |         args = [],
31 |         **kwargs):
32 |     """Generates a unit test for TPU.
33 | 
34 |     Args:
35 |       name: Name of test. Will be prefixed by accelerator versions.
36 |       tags: BUILD tags to apply to tests.
37 |       deps: Dependencies of the test.
38 |       args: Arguments to apply to tests.
39 |       **kwargs: Additional named arguments to apply to tests.
40 | 
41 |     """
42 |     tags = tags or []
43 |     deps = deps or []
44 |     kwargs.setdefault("main", "%s.py" % name)
45 |     kwargs.setdefault("python_version", "PY3")
46 | 
47 |     args = [
48 |         "--logtostderr",
49 |     ] + args
50 | 
51 |     tags = [
52 |         "requires-tpu",
53 |     ] + tags
54 | 
55 |     pytype_strict_contrib_test(
56 |         name = name,
57 |         tags = tags,
58 |         deps = deps,
59 |         args = args,
60 |         **kwargs
61 |     )
62 | 
--------------------------------------------------------------------------------
/third_party/bazel/python/pypi.bzl:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for working with pypi dependencies."""
15 | 
16 | load("@pypi//:requirements.bzl", "requirement")
17 | 
18 | # Map for python packages whose distribution names don't precisely correspond
19 | # to their import names. For example, the 'absl' import maps to the 'absl_py'
20 | # package. These are the exceptions; most packages correspond directly (e.g. jax, numpy).
21 | _PYPI_PACKAGE_MAP = {
22 |     "absl": "absl_py",
23 |     "google/protobuf": "protobuf",
24 | }
25 | 
26 | def pypi_requirement(dep):
27 |     """Determines the pypi package dependency for a target.
28 | 
29 |     Args:
30 |       dep: dependency target
31 | 
32 |     Returns:
33 |       pypi requirement.
34 |     """
35 |     package_name = dep
36 | 
37 |     # Remove any explicit target name (e.g. "flax:core" -> "flax").
38 |     target_sep = package_name.find(":")
39 |     if target_sep >= 0:
40 |         package_name = package_name[:target_sep]
41 | 
42 |     # Check the map for a direct dependency replacement.
43 |     package_name = _PYPI_PACKAGE_MAP.get(package_name, package_name)
44 | 
45 |     # Remove any subpackage names.
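# For example, "jax/experimental:layout" first drops ":layout" and then the
# "/experimental" subpackage, resolving to the "jax" distribution; "flax/nnx"
# likewise resolves to "flax".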
46 |     path_sep = package_name.find("/")
47 |     if path_sep >= 0:
48 |         package_name = package_name[:path_sep]
49 | 
50 |     # Replace any known package name substitutions.
51 |     package_name = _PYPI_PACKAGE_MAP.get(package_name, package_name)
52 | 
53 |     return requirement(package_name)
54 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/all_reduce_interface.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The JAX SC Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ALL_REDUCE_INTERFACE_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ALL_REDUCE_INTERFACE_H_
16 | 
17 | #include <cstdint>
18 | 
19 | #include "absl/status/statusor.h"  // from @com_google_absl
20 | 
21 | namespace jax_sc_embedding {
22 | 
23 | // Interface for performing all-reduce operations across multiple participants.
24 | // Implementations of this interface are used to synchronize and aggregate
25 | // values (e.g., boolean flags or uint64_t masks) across different hosts or
26 | // devices using a shared `sync_key`.
27 | class AllReduceInterface {
28 |  public:
29 |   virtual ~AllReduceInterface() = default;
30 |   // Performs a blocking all-reduce operation for a boolean value.
31 |   // The result is typically the logical OR of all `minibatching_required`
32 |   // values from participants sharing the same `sync_key`.
33 |   virtual absl::StatusOr<bool> BlockingAllReduce(
34 |       int sync_key, bool minibatching_required) = 0;
35 | 
36 |   // Performs a blocking all-reduce operation for a uint64_t value.
37 |   // The result is typically the logical OR of all `minibatching_split`
38 |   // values from participants sharing the same `sync_key`.
39 |   virtual absl::StatusOr<uint64_t> BlockingAllReduce(
40 |       int sync_key, uint64_t minibatching_split) = 0;
41 | };
42 | }  // namespace jax_sc_embedding
43 | 
44 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ALL_REDUCE_INTERFACE_H_
45 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | load("@bazel_skylib//rules:common_settings.bzl", "string_flag") 15 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 16 | load(":configure_file.bzl", "configure_file") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = ["//visibility:public"], 21 | ) 22 | 23 | package_group( 24 | name = "internal", 25 | packages = [ 26 | "//jax_tpu_embedding/sparsecore/...", 27 | ], 28 | ) 29 | 30 | # Sets the version suffix for the package, e.g. dev20250321-805775f. 31 | string_flag( 32 | name = "version_suffix", 33 | build_setting_default = "", 34 | ) 35 | 36 | # Sets the git commit for the package, e.g. 805775fcb5f9272e4c52dce751b00cf7f70364f2. 37 | string_flag( 38 | name = "git_commit", 39 | build_setting_default = "", 40 | ) 41 | 42 | configure_file( 43 | name = "version", 44 | flag_substitutions = { 45 | "VERSION_SUFFIX": ":version_suffix", 46 | "GIT_COMMIT": ":git_commit", 47 | }, 48 | output = "version.py", 49 | template = "version.py.in", 50 | ) 51 | 52 | pytype_strict_library( 53 | name = "sparsecore", 54 | srcs = [ 55 | "__init__.py", 56 | ":version", 57 | ], 58 | deps = [ 59 | "//jax_tpu_embedding/sparsecore/lib", # buildcleaner: keep 60 | "//jax_tpu_embedding/sparsecore/utils", # buildcleaner: keep 61 | ], 62 | ) 63 | 64 | # copybara:uncomment_begin(internal) 65 | # exports_files(srcs = ["METADATA"]) 66 | # copybara:uncomment_end 67 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/linen/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = EXTERNAL_USERS, 21 | ) 22 | 23 | pytype_strict_library( 24 | name = "embed", 25 | srcs = [ 26 | "embed.py", 27 | ], 28 | deps = [ 29 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 30 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 31 | "//jax_tpu_embedding/sparsecore/utils", 32 | pypi_requirement("flax:core"), 33 | pypi_requirement("jax"), 34 | pypi_requirement("jax/experimental:layout"), 35 | pypi_requirement("numpy"), 36 | ], 37 | ) 38 | 39 | pytype_strict_library( 40 | name = "embed_optimizer", 41 | srcs = [ 42 | "embed_optimizer.py", 43 | ], 44 | deps = [ 45 | ":embed", 46 | pypi_requirement("jax"), 47 | pypi_requirement("optax"), 48 | ], 49 | ) 50 | 51 | # Library target. 
52 | pytype_strict_library( 53 | name = "linen", 54 | srcs = ["__init__.py"], 55 | visibility = [ 56 | "//jax_tpu_embedding/sparsecore/lib:__pkg__", 57 | "//jax_tpu_embedding/sparsecore/lib/flax:__pkg__", 58 | ], 59 | deps = [ 60 | ":embed", # buildcleaner: keep 61 | ":embed_optimizer", # buildcleaner: keep 62 | ], 63 | ) 64 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/fdo/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 15 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library") 16 | 17 | package( 18 | default_applicable_licenses = ["//:license"], 19 | default_visibility = ["//jax_tpu_embedding/sparsecore:internal"], 20 | ) 21 | 22 | pytype_strict_library( 23 | name = "fdo_client", 24 | srcs = ["fdo_client.py"], 25 | deps = ["//jax_tpu_embedding/sparsecore/lib/nn:embedding"], 26 | ) 27 | 28 | pytype_strict_library( 29 | name = "file_fdo_client", 30 | srcs = ["file_fdo_client.py"], 31 | deps = [ 32 | ":fdo_client", 33 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 34 | pypi_requirement("absl/logging"), 35 | pypi_requirement("jax"), 36 | pypi_requirement("numpy"), 37 | ], 38 | ) 39 | 40 | pytype_strict_contrib_test( 41 | name = "file_fdo_client_test", 42 | srcs = ["file_fdo_client_test.py"], 43 | env = {"JAX_PLATFORMS": "cpu"}, 44 | deps = [ 45 | ":file_fdo_client", 46 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 47 | pypi_requirement("absl/testing:absltest"), 48 | pypi_requirement("numpy"), 49 | ], 50 | ) 51 | 52 | # Library target. 53 | pytype_strict_library( 54 | name = "fdo", 55 | srcs = ["__init__.py"], 56 | visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"], 57 | deps = [ 58 | ":fdo_client", # buildcleaner: keep 59 | ":file_fdo_client", # buildcleaner: keep 60 | ], 61 | ) 62 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/unity_weights_stream_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_UNITY_WEIGHTS_STREAM_IMPL_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_UNITY_WEIGHTS_STREAM_IMPL_H_
16 | 
17 | namespace jax_sc_embedding {
18 | 
19 | // Class to iterate over a sparse CSR array, providing unity weights for each
20 | // value. This class takes an existing `ValuesStream` (e.g.,
21 | // `SparseCsrInputBatchStream`) and provides an interface to iterate over the
22 | // same structure, but returning a weight of 1.0 for each value instead of the
23 | // actual value. This is useful when the input does not have associated weights
24 | // but the processing logic expects a weight stream.
25 | template <typename ValuesStream>
26 | class UnityWeightsStream {
27 |  public:
28 |   UnityWeightsStream(const ValuesStream& value_stream)
29 |       : value_stream_(value_stream), curr_col_(0) {}
30 | 
31 |   int size() const { return value_stream_.size(); }
32 | 
33 |   int cols() const { return value_stream_.cols(); }
34 | 
35 |   void NextRow() { curr_col_ = 0; }
36 | 
37 |   void NextCol() { ++curr_col_; }
38 | 
39 |   void SeekCol(int col) { curr_col_ = col; }
40 | 
41 |   int row() const { return value_stream_.row(); }
42 | 
43 |   int col() const { return curr_col_; }
44 | 
45 |   float get() const { return 1.0f; }
46 | 
47 |  private:
48 |   const ValuesStream& value_stream_;
49 |   int curr_col_;
50 | };
51 | 
52 | template <typename T>
53 | UnityWeightsStream(T) -> UnityWeightsStream<T>;
54 | 
55 | }  // namespace jax_sc_embedding
56 | 
57 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_UNITY_WEIGHTS_STREAM_IMPL_H_
58 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/examples/models/shakespeare/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = EXTERNAL_USERS, 21 | ) 22 | 23 | pytype_strict_library( 24 | name = "dataset", 25 | srcs = [ 26 | "dataset.py", 27 | ], 28 | deps = [ 29 | pypi_requirement("absl/logging"), 30 | pypi_requirement("numpy"), 31 | ], 32 | ) 33 | 34 | pytype_strict_library( 35 | name = "model", 36 | srcs = [ 37 | "model.py", 38 | ], 39 | deps = [ 40 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 41 | pypi_requirement("flax:core"), 42 | pypi_requirement("jax"), 43 | pypi_requirement("optax"), 44 | ], 45 | ) 46 | 47 | pytype_strict_library( 48 | name = "flax_model", 49 | srcs = [ 50 | "flax_model.py", 51 | ], 52 | deps = [ 53 | "//jax_tpu_embedding/sparsecore/lib/flax/linen:embed", 54 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 55 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 56 | pypi_requirement("flax:core"), 57 | pypi_requirement("jax"), 58 | ], 59 | ) 60 | 61 | pytype_strict_library( 62 | name = "flax_nnx_model", 63 | srcs = [ 64 | "flax_nnx_model.py", 65 | ], 66 | deps = [ 67 | "//jax_tpu_embedding/sparsecore/lib/flax/nnx:embed", 68 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 69 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 70 | pypi_requirement("flax/nnx"), 71 | pypi_requirement("jax"), 72 | ], 73 | ) 74 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/minibatching_test_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_TEST_UTILS_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_TEST_UTILS_H_ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "absl/strings/str_cat.h" // from @com_google_absl 22 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/minibatching_node.h" 23 | #include "tsl/platform/test.h" // from @tsl 24 | 25 | namespace jax_sc_embedding { 26 | namespace testing_utils { 27 | 28 | // Helper function to set up MinibatchingNode instances for each host. 
29 | inline std::vector<std::unique_ptr<rpc::MinibatchingNode>>
30 | SetUpMinibatchingNodes(int num_hosts, int threads_per_task = 1) {
31 |   std::vector<int> ports;
32 |   ports.reserve(num_hosts);
33 |   for (int i = 0; i < num_hosts; ++i) {
34 |     ports.push_back(tsl::testing::PickUnusedPortOrDie());
35 |   }
36 | 
37 |   std::vector<std::string> peer_addresses;
38 |   peer_addresses.reserve(num_hosts);
39 |   for (int i = 0; i < num_hosts; ++i) {
40 |     peer_addresses.push_back(absl::StrCat("localhost:", ports[i]));
41 |   }
42 | 
43 |   std::vector<std::unique_ptr<rpc::MinibatchingNode>> nodes;
44 |   nodes.reserve(num_hosts);
45 |   for (int i = 0; i < num_hosts; ++i) {
46 |     std::vector<std::string> other_peer_addresses;
47 |     for (int j = 0; j < num_hosts; ++j) {
48 |       if (i == j) continue;
49 |       other_peer_addresses.push_back(peer_addresses[j]);
50 |     }
51 |     nodes.push_back(std::make_unique<rpc::MinibatchingNode>(
52 |         /*task_id=*/i, /*num_tasks=*/num_hosts, other_peer_addresses, ports[i],
53 |         threads_per_task));
54 |   }
55 |   return nodes;
56 | }
57 | 
58 | }  // namespace testing_utils
59 | }  // namespace jax_sc_embedding
60 | 
61 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_TEST_UTILS_H_
62 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The JAX SC Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
15 | 
16 | #include <algorithm>
17 | #include <cstdlib>
18 | 
19 | #include "absl/log/check.h"  // from @com_google_absl
20 | #include "absl/log/log.h"  // from @com_google_absl
21 | #include "absl/strings/numbers.h"  // from @com_google_absl
22 | #include "tsl/platform/env.h"  // from @tsl
23 | #include "tsl/platform/threadpool.h"  // from @tsl
24 | #include "tsl/platform/cpu_info.h"  // from @tsl
25 | 
26 | namespace jax_sc_embedding {
27 | 
28 | namespace {
29 | 
30 | constexpr char kScEnv[] = "SPARSECORE_INPUT_PREPROCESSING_THREADS";
31 | constexpr char kScPool[] = "SparseCoreInputPreprocessingThreadPool";
32 | 
33 | // Returns the thread pool size: the minimum of NumSchedulableCPUs() and the
34 | // value of the environment variable `SPARSECORE_INPUT_PREPROCESSING_THREADS`
35 | // (when set to a valid positive integer), and always at least one.
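// For example, on a machine with 96 schedulable CPUs,
// SPARSECORE_INPUT_PREPROCESSING_THREADS=16 yields a 16-thread pool, while an
// unset, unparsable, or out-of-range value falls back to all 96.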
36 | int GetThreadPoolSize() { 37 | int num_threads = tsl::port::NumSchedulableCPUs(); 38 | if (const char* env = std::getenv(kScEnv); env != nullptr) { 39 | int n; 40 | if (absl::SimpleAtoi(env, &n) && 0 < n && n < num_threads) { 41 | num_threads = n; 42 | } 43 | } 44 | return std::max(1, num_threads); 45 | } 46 | 47 | } // namespace 48 | 49 | tsl::thread::ThreadPool* PreprocessingThreadPool() { 50 | static tsl::thread::ThreadPool* pool = []() { 51 | const int num_threads = GetThreadPoolSize(); 52 | DCHECK_GE(num_threads, 1); 53 | LOG(INFO) << "Creating thread pool for SparseCore input " 54 | "preprocessing: " 55 | << num_threads << " threads"; 56 | auto thread_pool = new tsl::thread::ThreadPool( 57 | tsl::Env::Default(), tsl::ThreadOptions(), kScPool, num_threads, 58 | /*low_latency_hint=*/false); 59 | return thread_pool; 60 | }(); 61 | return pool; 62 | } 63 | 64 | } // namespace jax_sc_embedding 65 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for examples.""" 15 | 16 | from absl import flags 17 | import jax 18 | from jax.experimental import layout 19 | 20 | if jax.__version_info__ >= (0, 6, 3): 21 | Layout = layout.Layout 22 | else: 23 | Layout = layout.DeviceLocalLayout # type: ignore 24 | 25 | 26 | _DUMP_DIR = flags.DEFINE_string( 27 | 'dump_dir', None, 'Directory to write debug dumps to.' 28 | ) 29 | 30 | NUM_SC_PER_DEVICE_MAP = { 31 | 'TPU v5': 4, 32 | 'TPU v6 lite': 2, 33 | 'TPU7x': 4, 34 | } 35 | 36 | 37 | def num_sparsecores_per_device(device: jax.Device | None = None): 38 | """Determine the number of sparsecores available on a device. 39 | 40 | Args: 41 | device: JAX device to check. If None, queries the first device in 42 | jax.devices(). 43 | 44 | Returns: 45 | Number of sparsecores. 46 | 47 | Raises: 48 | ValueError: if the number of sparsecores cannot be determined. 
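  Example (actual values depend on the attached TPU; see NUM_SC_PER_DEVICE_MAP):

    num_sc = num_sparsecores_per_device()  # e.g. 2 on 'TPU v6 lite'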
49 | """ 50 | device = device or jax.devices()[0] 51 | 52 | if not hasattr(device, 'device_kind'): 53 | raise ValueError(f'Cannot determine device kind for device: {device}') 54 | 55 | device_kind = device.device_kind 56 | if device_kind not in NUM_SC_PER_DEVICE_MAP: 57 | raise ValueError(f'Unknown sparsecore count for device kind: {device_kind}') 58 | 59 | return NUM_SC_PER_DEVICE_MAP[device_kind] 60 | 61 | 62 | def tree_summary(tree): 63 | """Returns the shape and dtype of each leaf in the tree.""" 64 | return jax.tree.map(lambda x: (x.shape, x.dtype), tree) 65 | 66 | 67 | def embedding_table_format( 68 | mesh: jax.sharding.Mesh, partition_spec: jax.sharding.PartitionSpec 69 | ) -> jax.sharding.Sharding: 70 | """Returns the layout format of the embedding table.""" 71 | return layout.Format( # pytype: disable=bad-return-type 72 | Layout( 73 | major_to_minor=(0, 1), 74 | tiling=((8,),), 75 | ), 76 | jax.sharding.NamedSharding(mesh, partition_spec), 77 | ) 78 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/nnx/tests/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "tpu_py_strict_test") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | 17 | package( 18 | default_applicable_licenses = ["//:license"], 19 | default_visibility = ["//visibility:private"], 20 | ) 21 | 22 | tpu_py_strict_test( 23 | name = "autograd_test", 24 | srcs = [ 25 | "autograd_test.py", 26 | ], 27 | env = { 28 | "XLA_FLAGS": "--xla_dump_to=sponge", 29 | "JAX_TRACEBACK_FILTERING": "off", 30 | }, 31 | deps = [ 32 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset", 33 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:flax_nnx_model", 34 | "//jax_tpu_embedding/sparsecore/lib/flax/nnx:embed", 35 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 36 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 37 | "//jax_tpu_embedding/sparsecore/utils", 38 | "//third_party/jax:tpu_support", 39 | pypi_requirement("absl/flags"), 40 | pypi_requirement("absl/logging"), 41 | pypi_requirement("absl/testing:absltest"), 42 | pypi_requirement("flax/nnx"), 43 | pypi_requirement("jax"), 44 | pypi_requirement("numpy"), 45 | pypi_requirement("optax"), 46 | ], 47 | ) 48 | 49 | tpu_py_strict_test( 50 | name = "embed_test", 51 | srcs = ["embed_test.py"], 52 | env = { 53 | # "XLA_FLAGS": "--xla_dump_to=sponge", 54 | "JAX_TRACEBACK_FILTERING": "off", 55 | }, 56 | deps = [ 57 | "//jax_tpu_embedding/sparsecore/lib/flax/nnx:embed", 58 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 59 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 60 | "//jax_tpu_embedding/sparsecore/lib/nn/tests:test_utils", 61 | "//jax_tpu_embedding/sparsecore/utils", 62 | "//third_party/jax:tpu_support", 63 | pypi_requirement("absl/testing:absltest"), 64 | pypi_requirement("absl/testing:parameterized"), 65 | pypi_requirement("einops"), 66 | pypi_requirement("flax/nnx"), 67 | pypi_requirement("jax"), 68 | pypi_requirement("numpy"), 69 | pypi_requirement("portpicker"), 70 | ], 71 | ) 72 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Shakespeare model using embedding layer.""" 15 | 16 | from flax import linen as nn 17 | import jax 18 | import jax.numpy as jnp 19 | from jax_tpu_embedding.sparsecore.lib.flax.linen import embed 20 | from jax_tpu_embedding.sparsecore.lib.nn import embedding 21 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec 22 | 23 | Nested = embedding.Nested 24 | 25 | 26 | ################################################################################ 27 | # Define the model. 
28 | ################################################################################ 29 | class Model(nn.Module): 30 | """Shakespeare model using embedding layer.""" 31 | 32 | feature_specs: Nested[embedding_spec.FeatureSpec] 33 | global_batch_size: int 34 | vocab_size: int 35 | seq_len: int 36 | embedding_size: int 37 | mesh: jax.sharding.Mesh 38 | feature_name: str = 'shakespeare_feature' 39 | sharding_axis: str = 'sparsecore_sharding' 40 | 41 | def add_sharding_constraint(self, x: jax.Array): 42 | """Add a sharding constraint to the array. 43 | 44 | Add a sharding constraint to the array to ensure that the sharding 45 | information is not lost during compilation. This may not be necessary but 46 | it helps SPMD and ensures that the sharding information is as expected. 47 | 48 | Args: 49 | x: The array to add the sharding constraint to. 50 | 51 | Returns: 52 | The array with the sharding constraint added. 53 | """ 54 | return jax.lax.with_sharding_constraint( 55 | x, 56 | jax.sharding.NamedSharding( 57 | self.mesh, jax.sharding.PartitionSpec(self.sharding_axis) 58 | ), 59 | ) 60 | 61 | @nn.compact 62 | def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput): 63 | # Run the embedding layer. 64 | x = embed.SparseCoreEmbed( 65 | feature_specs=self.feature_specs, 66 | mesh=self.mesh, 67 | sharding_axis=self.sharding_axis, 68 | )(embedding_lookup_inputs) 69 | 70 | # Unpack the activations. 71 | x = x[self.feature_name] 72 | x = jnp.reshape(x, (self.global_batch_size, -1)) 73 | x = self.add_sharding_constraint(x) 74 | 75 | # Apply the dense portion of the model. 76 | x = nn.Dense(self.embedding_size)(x) 77 | x = self.add_sharding_constraint(x) 78 | x = nn.Dense(self.vocab_size)(x) 79 | x = self.add_sharding_constraint(x) 80 | 81 | return x 82 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
16 | 
17 | #include <memory>
18 | #include <string>
19 | #include <vector>
20 | 
21 | #include "absl/container/flat_hash_map.h"  // from @com_google_absl
22 | #include "absl/status/statusor.h"  // from @com_google_absl
23 | #include "absl/types/span.h"  // from @com_google_absl
24 | #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
25 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
26 | 
27 | namespace jax_sc_embedding {
28 | 
29 | struct SparseDenseMatmulInputStats {
30 |   StackedTableMap max_ids_per_partition;
31 |   StackedTableMap max_unique_ids_per_partition;
32 |   StackedTableMap required_buffer_sizes;
33 |   StackedTableMap dropped_id_count;
34 | 
35 |   int TotalDroppedIdCount() const {
36 |     int sum = 0;
37 |     for (const auto& [_, v] : dropped_id_count) sum += v;
38 |     return sum;
39 |   }
40 |   // Merge another SparseDenseMatmulInputStats object into the current one.
41 |   void merge(const SparseDenseMatmulInputStats& other);
42 | };
43 | 
44 | namespace internal {
45 | ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
46 |     absl::Span<const StackedTableMetadata> stacked_table_metadata,
47 |     absl::Span<std::unique_ptr<AbstractInputBatch>> input_batches,
48 |     int local_device_id, const PreprocessSparseDenseMatmulInputOptions& options,
49 |     bool has_variable_weights = false);
50 | }  // namespace internal
51 | 
52 | struct PreprocessSparseDenseMatmulOutput {
53 |   StackedTableMap lhs_row_pointers;
54 |   StackedTableMap lhs_embedding_ids;
55 |   StackedTableMap lhs_sample_ids;
56 |   StackedTableMap lhs_gains;
57 |   int num_minibatches;
58 |   SparseDenseMatmulInputStats stats;
59 | };
60 | 
61 | absl::StatusOr<PreprocessSparseDenseMatmulOutput>
62 | PreprocessSparseDenseMatmulInput(
63 |     absl::Span<std::unique_ptr<AbstractInputBatch>> input_batches,
64 |     const absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>&
65 |         stacked_tables,
66 |     const PreprocessSparseDenseMatmulInputOptions& options,
67 |     OutputCsrArrays* output_csr_arrays = nullptr);
68 | 
69 | }  // namespace jax_sc_embedding
70 | 
71 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
72 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/tests/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "tpu_py_strict_test") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = ["//visibility:private"], 21 | ) 22 | 23 | pytype_strict_contrib_test( 24 | name = "version_test", 25 | srcs = ["version_test.py"], 26 | data = ["//jax_tpu_embedding/sparsecore:version"], 27 | deps = [ 28 | "//jax_tpu_embedding/sparsecore", 29 | pypi_requirement("absl/testing:absltest"), 30 | ], 31 | ) 32 | 33 | tpu_py_strict_test( 34 | name = "jax_sc_shakespeare_tests", 35 | srcs = ["jax_sc_shakespeare_tests.py"], 36 | env = { 37 | "XLA_FLAGS": "--xla_dump_to=sponge", 38 | }, 39 | deps = [ 40 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset", 41 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:model", 42 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 43 | "//jax_tpu_embedding/sparsecore/utils", 44 | "//third_party/jax:tpu_support", 45 | pypi_requirement("absl/flags"), 46 | pypi_requirement("absl/logging"), 47 | pypi_requirement("absl/testing:absltest"), 48 | pypi_requirement("einops"), 49 | pypi_requirement("jax"), 50 | pypi_requirement("numpy"), 51 | pypi_requirement("optax"), 52 | pypi_requirement("orbax/checkpoint"), 53 | ], 54 | ) 55 | 56 | tpu_py_strict_test( 57 | name = "jax_spmd_tc_with_sc_tests", 58 | srcs = ["jax_spmd_tc_with_sc_tests.py"], 59 | env = { 60 | "XLA_FLAGS": "--xla_dump_to=sponge", 61 | }, 62 | deps = [ 63 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset", 64 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 65 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 66 | "//jax_tpu_embedding/sparsecore/utils", 67 | "//third_party/jax:tpu_support", 68 | pypi_requirement("absl/flags"), 69 | pypi_requirement("absl/logging"), 70 | pypi_requirement("absl/testing:absltest"), 71 | pypi_requirement("einops"), 72 | pypi_requirement("flax:core"), 73 | pypi_requirement("jax"), 74 | pypi_requirement("numpy"), 75 | pypi_requirement("optax"), 76 | ], 77 | ) 78 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/configure_file.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Generate a file using a template.""" 15 | 16 | load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") 17 | 18 | def _get_flag_substitutions(flag_substitutions): 19 | """Extracts flag values.""" 20 | substitutions = {} 21 | for key, label in flag_substitutions.items(): 22 | substitutions[key] = label[BuildSettingInfo].value 23 | return substitutions 24 | 25 | def _create_substitution_map(string_substitutions): 26 | """Replaces {key: value} with {${key}: value}""" 27 | substitutions = {} 28 | for key, value in string_substitutions.items(): 29 | key_var = "${" + key + "}" 30 | substitutions[key_var] = value 31 | return substitutions 32 | 33 | def configure_file( 34 | name, 35 | template, 36 | output, 37 | substitutions = {}, 38 | flag_substitutions = {}): 39 | """Generates a file using a template. 40 | 41 | For every entry in the substitutions maps, replaces `${variable}` with `value`. 42 | 43 | Args: 44 | name: The name of the rule. 45 | template: The template file in which to perform the substitutions. 46 | output: The output file. 47 | substitutions: A map of string substitutions {variable: value}. 48 | flag_substitutions: A map of variable to bazel string_flag substitutions. \ 49 | Replacement values are extracted from the flag. 50 | 51 | Returns: 52 | A rule that generates the output file. 53 | """ 54 | _configure_file( 55 | name = name, 56 | template = template, 57 | output = output, 58 | substitutions = substitutions, 59 | flag_substitutions = flag_substitutions, 60 | ) 61 | 62 | def _configure_file_impl(ctx): 63 | substitutions = _create_substitution_map(ctx.attr.substitutions | _get_flag_substitutions(ctx.attr.flag_substitutions)) 64 | ctx.actions.expand_template( 65 | template = ctx.file.template, 66 | output = ctx.outputs.output, 67 | substitutions = substitutions, 68 | ) 69 | 70 | _configure_file = rule( 71 | implementation = _configure_file_impl, 72 | attrs = { 73 | "template": attr.label( 74 | mandatory = True, 75 | allow_single_file = True, 76 | ), 77 | "substitutions": attr.string_dict(), 78 | "flag_substitutions": attr.string_keyed_label_dict(), 79 | "output": attr.output(mandatory = True), 80 | }, 81 | ) 82 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/linen/tests/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "tpu_py_strict_test") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | 17 | package( 18 | default_applicable_licenses = ["//:license"], 19 | default_visibility = ["//visibility:private"], 20 | ) 21 | 22 | tpu_py_strict_test( 23 | name = "embed_test", 24 | srcs = ["embed_test.py"], 25 | deps = [ 26 | "//jax_tpu_embedding/sparsecore/lib/flax/linen:embed", 27 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 28 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 29 | "//jax_tpu_embedding/sparsecore/lib/nn/tests:test_utils", 30 | "//jax_tpu_embedding/sparsecore/utils", 31 | "//third_party/jax:tpu_support", 32 | pypi_requirement("absl/testing:absltest"), 33 | pypi_requirement("absl/testing:parameterized"), 34 | pypi_requirement("einops"), 35 | pypi_requirement("flax:core"), 36 | pypi_requirement("jax"), 37 | pypi_requirement("numpy"), 38 | ], 39 | ) 40 | 41 | tpu_py_strict_test( 42 | name = "autograd_test", 43 | srcs = [ 44 | "autograd_test.py", 45 | ], 46 | env = { 47 | "XLA_FLAGS": "--xla_dump_to=sponge", 48 | "JAX_TRACEBACK_FILTERING": "off", 49 | }, 50 | deps = [ 51 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset", 52 | "//jax_tpu_embedding/sparsecore/examples/models/shakespeare:flax_model", 53 | "//jax_tpu_embedding/sparsecore/lib/flax/linen:embed_optimizer", 54 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding", 55 | "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", 56 | "//jax_tpu_embedding/sparsecore/utils", 57 | "//third_party/jax:tpu_support", 58 | pypi_requirement("absl/flags"), 59 | pypi_requirement("absl/logging"), 60 | pypi_requirement("absl/testing:absltest"), 61 | pypi_requirement("flax:core"), 62 | pypi_requirement("jax"), 63 | pypi_requirement("numpy"), 64 | pypi_requirement("optax"), 65 | ], 66 | ) 67 | 68 | tpu_py_strict_test( 69 | name = "embed_optimizer_test", 70 | srcs = ["embed_optimizer_test.py"], 71 | deps = [ 72 | "//jax_tpu_embedding/sparsecore/lib/flax/linen:embed_optimizer", 73 | "//third_party/jax:tpu_support", 74 | pypi_requirement("absl/testing:absltest"), 75 | pypi_requirement("flax:core"), 76 | pypi_requirement("jax"), 77 | pypi_requirement("optax"), 78 | ], 79 | ) 80 | -------------------------------------------------------------------------------- /.github/workflows/build_and_test.yml: -------------------------------------------------------------------------------- 1 | name: Build and test 2 | 3 | on: 4 | # Only run workflow on pushes to main (includes PR merge), and on 5 | # opened pull-requests. 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | 11 | permissions: 12 | contents: read # Lets the job check out your code 13 | actions: write # Lets the job write (save) the cache 14 | 15 | jobs: 16 | build-and-test-cpu: 17 | runs-on: linux-x86-n4-16 18 | container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: '3.10' 26 | 27 | - name: Display Python version 28 | run: python -c "import sys; print(sys.version)" 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade --root-user-action=ignore pip setuptools wheel 33 | bash tools/install_bazelisk.sh 34 | 35 | # --- FOR PULL REQUESTS --- 36 | # Restore the cache from the main branch's base commit. 
37 | - if: github.event_name == 'pull_request' 38 | name: Mount bazel cache (pull-request) 39 | uses: actions/cache/restore@v4 40 | with: 41 | path: "/__w/.cache/bazel" 42 | key: bazel-py3.10-${{ github.base_ref }}-${{ github.event.pull_request.base.sha }} 43 | restore-keys: | 44 | bazel-py3.10-${{ github.base_ref }} 45 | bazel-py3.10- 46 | bazel- 47 | 48 | # --- FOR MAIN PUSHES AND MERGES --- 49 | # Restore the cache from the previous commit. 50 | - if: github.event_name != 'pull_request' 51 | name: Mount bazel cache (main) 52 | uses: actions/cache/restore@v4 53 | with: 54 | path: "/__w/.cache/bazel" 55 | # Try to find the cache from the exact commit *before* this push 56 | key: bazel-py3.10-${{ github.ref_name }}-${{ github.event.before }} 57 | # Fall back to the most recent cache for the main branch 58 | restore-keys: | 59 | bazel-py3.10-${{ github.ref_name }} 60 | bazel-py3.10- 61 | bazel- 62 | 63 | - name: Build all targets 64 | run: | 65 | export HERMETIC_PYTHON_VERSION=3.10 66 | bazel build --config=public_cache //... --build_tag_filters=-oss_exclude 67 | 68 | - name: Build wheel 69 | run: | 70 | export HERMETIC_PYTHON_VERSION=3.10 71 | bazel run --config=public_cache //tools:build_wheel 72 | 73 | - name: Run CPU tests 74 | run: | 75 | export HERMETIC_PYTHON_VERSION=3.10 76 | bazel test --config=public_cache --build_tests_only --test_tag_filters=-requires-tpu --test_output=errors --keep_going //... 77 | 78 | - name: Check disk space 79 | run: | 80 | df -h --total 81 | du -sh /__w/.cache/bazel 82 | 83 | # --- SAVE NEW CACHE --- 84 | - name: Save bazel cache (main) 85 | if: github.event_name != 'pull_request' 86 | uses: actions/cache/save@v4 87 | with: 88 | path: "/__w/.cache/bazel" 89 | key: bazel-py3.10-${{ github.ref_name }}-${{ github.sha }} 90 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/minibatching_node.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_MINIBATCHING_NODE_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_MINIBATCHING_NODE_H_ 16 | 17 | #include <memory> 18 | #include <string> 19 | #include <vector> 20 | 21 | #include "absl/base/attributes.h" // from @com_google_absl 22 | #include "absl/strings/str_cat.h" // from @com_google_absl 23 | #include "include/grpcpp/server_builder.h" // from @com_github_grpc_grpc 24 | #include "include/grpcpp/server_context.h" // from @com_github_grpc_grpc 25 | #include "jax_tpu_embedding/sparsecore/lib/core/all_reduce_interface.h" 26 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce_interface.h" 27 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce_service_impl.h" 28 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/grpc_credentials.h" 29 | 30 | namespace jax_sc_embedding { 31 | namespace rpc { 32 | 33 | // This class encapsulates the gRPC server and client-side interface required 34 | // for performing all-reduce operations across multiple hosts in a minibatching 35 | // setup. It initializes an `AllReduceServiceImpl` to handle incoming 36 | // All-Reduce RPCs from other peers and a `GrpcAllReduceInterface` to 37 | // initiate All-Reduce RPCs to other peers. 38 | class MinibatchingNode { 39 | public: 40 | MinibatchingNode(int task_id, int num_tasks, 41 | std::vector<std::string> peer_addresses, 42 | int minibatching_port, int threads_per_task = 1) 43 | : all_reduce_service_(std::make_unique<AllReduceServiceImpl>( 44 | task_id, num_tasks, threads_per_task)), 45 | all_reduce_interface_(std::make_unique<GrpcAllReduceInterface>( 46 | peer_addresses, task_id, num_tasks, all_reduce_service_.get(), 47 | threads_per_task)), 48 | all_reduce_server_( 49 | ::grpc::ServerBuilder() 50 | .AddListeningPort(absl::StrCat("[::]:", minibatching_port), 51 | GetDefaultServerCredentials()) 52 | .RegisterService(all_reduce_service_.get()) 53 | .BuildAndStart()) {} 54 | 55 | AllReduceInterface* GetAllReduceInterface() ABSL_ATTRIBUTE_LIFETIME_BOUND { 56 | return all_reduce_interface_.get(); 57 | } 58 | 59 | private: 60 | std::unique_ptr<AllReduceServiceImpl> all_reduce_service_; 61 | std::unique_ptr<GrpcAllReduceInterface> all_reduce_interface_; 62 | std::unique_ptr<::grpc::Server> all_reduce_server_; 63 | }; 64 | } // namespace rpc 65 | } // namespace jax_sc_embedding 66 | 67 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_MINIBATCHING_NODE_H_ 68 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
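A note on the reduction semantics used by `MinibatchingNode` above: `BlockingAllReduce` is declared once for a `bool` ("does any host require minibatching?") and once for a `uint64_t` split mask (see `all_reduce_interface.h` later in this listing). A minimal Python sketch of that contract, assuming the reduction is a logical/bitwise OR across tasks — an assumption the interface suggests but does not spell out; the function names here are illustrative, not part of the library:

```python
# Conceptual sketch only; assumes an OR reduction across tasks.
import functools
import operator


def all_reduce_minibatching_required(per_task_flags: list[bool]) -> bool:
  # Global minibatching is required if any single task requires it.
  return functools.reduce(operator.or_, per_task_flags, False)


def all_reduce_minibatching_split(per_task_splits: list[int]) -> int:
  # Each task proposes a 64-bit split mask; the global split is the union
  # (bitwise OR) of every proposal.
  return functools.reduce(operator.or_, per_task_splits, 0)


assert all_reduce_minibatching_required([False, True, False])
assert all_reduce_minibatching_split([0b0101, 0b0011]) == 0b0111
```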
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ABSTRACT_INPUT_BATCH_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ABSTRACT_INPUT_BATCH_H_ 16 | #include <cstdint> 17 | 18 | #include <optional> 19 | #include <vector> 20 | 21 | #include "absl/base/attributes.h" // from @com_google_absl 22 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" 23 | 24 | namespace jax_sc_embedding { 25 | // NOTE: Converting input data to a C++ native type can be expensive. Therefore, 26 | // we define a read-only wrapper to abstract the input data. 27 | // Represents a batch of inputs for a single Feature and corresponding weights. 28 | class AbstractInputBatch { 29 | public: 30 | struct ExtractCooTensorsOptions { 31 | // Start index of the slice to be processed (inclusive). 32 | const int slice_start ABSL_REQUIRE_EXPLICIT_INIT; 33 | // End index of the slice to be processed (exclusive). 34 | const int slice_end ABSL_REQUIRE_EXPLICIT_INIT; 35 | // Row offset to be added to the sample id. 36 | const int row_offset ABSL_REQUIRE_EXPLICIT_INIT; 37 | // Column offset to be added to the embedding id. 38 | const int col_offset ABSL_REQUIRE_EXPLICIT_INIT; 39 | // Number of bits to shift the embedding id. 40 | const int col_shift ABSL_REQUIRE_EXPLICIT_INIT; 41 | // Number of sparse cores per device. Used to compute COO tensor counts per 42 | // SC. 43 | const int num_sc_per_device ABSL_REQUIRE_EXPLICIT_INIT; 44 | // Number of sparse cores. 45 | const uint32_t num_scs ABSL_REQUIRE_EXPLICIT_INIT; 46 | // Combiner to be used for the row. 47 | const RowCombiner combiner ABSL_REQUIRE_EXPLICIT_INIT; 48 | }; 49 | 50 | // Returns the number of samples (e.g., rows) in this input batch. 51 | virtual int64_t size() const = 0; 52 | 53 | // Returns the total number of embedding IDs across all samples. 54 | virtual int64_t id_count() const = 0; 55 | 56 | // Returns number of ids in rows [start_row, end_row). 57 | // If not implemented by a subclass, returns std::nullopt. 58 | virtual std::optional<int64_t> GetIdsCountInSlice(int start_row, 59 | int end_row) const { 60 | return std::nullopt; 61 | } 62 | 63 | // Returns true if the input batch has variable weights. 64 | virtual bool HasVariableWeights() const { return true; } 65 | 66 | virtual void ExtractCooTensors( 67 | const ExtractCooTensorsOptions& options, 68 | ExtractedCooTensors& extracted_coo_tensors) = 0; 69 | 70 | virtual ~AbstractInputBatch() = default; 71 | }; 72 | 73 | } // namespace jax_sc_embedding 74 | 75 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_ABSTRACT_INPUT_BATCH_H_ 76 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/linen/tests/embed_optimizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
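To make the `AbstractInputBatch` contract above concrete, here is a hedged Python analogue for a dense 2-D batch of embedding IDs. The class is a toy illustration, not the pybind-exposed implementation (the real `NumpySparseInputBatch` appears in `numpy_input_batch.h` further down and uses the same counting logic for the 2-D case):

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Dense2DInputBatch:
  """Toy stand-in: a [batch_size, ids_per_sample] array of embedding IDs."""

  features: np.ndarray

  def size(self) -> int:
    # Number of samples (rows) in the batch.
    return self.features.shape[0]

  def id_count(self) -> int:
    # Total number of embedding IDs across all samples.
    return self.features.shape[0] * self.features.shape[1]

  def get_ids_count_in_slice(self, start_row: int, end_row: int) -> int:
    # Exact for a fixed-width batch, mirroring the 2-D NumPy case.
    return (end_row - start_row) * self.features.shape[1]


batch = Dense2DInputBatch(np.arange(12).reshape(4, 3))
assert batch.size() == 4
assert batch.id_count() == 12
assert batch.get_ids_count_in_slice(1, 3) == 6
```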
14 | import operator 15 | from typing import Any 16 | 17 | from absl.testing import absltest 18 | from flax import struct 19 | import jax 20 | import jax.numpy as jnp 21 | from jax_tpu_embedding.sparsecore.lib.flax.linen import embed_optimizer 22 | import optax 23 | 24 | 25 | class EmbedOptimizerTest(absltest.TestCase): 26 | 27 | def test_create_and_apply_optimizer_for_sc_model(self): 28 | # Define sample model params containing the expected embedding path. 29 | class ModelParams(struct.PyTreeNode): 30 | params: Any # Unnecessary indirection to introduce a non-dict key 31 | 32 | model_params = ModelParams( 33 | params={ 34 | "layers_0": { 35 | "sc_embedding_variables": { 36 | "value": {"table": {"table": jnp.array([1.0, 2.0])}} 37 | } 38 | }, 39 | "layers_2": { 40 | "Dense_0": { 41 | "bias": jnp.array([1.0, 2.0]), 42 | "kernel": jnp.array([1.0, 2.0]), 43 | }, 44 | "Dense_1": { 45 | "bias": jnp.array([1.0, 2.0]), 46 | "kernel": jnp.array([1.0, 2.0]), 47 | }, 48 | }, 49 | } 50 | ) 51 | # Create the optimizer 52 | optimizer: optax.GradientTransformation = ( 53 | embed_optimizer.create_optimizer_for_sc_model( 54 | model_params, 55 | tc_optimizer=optax.sgd(learning_rate=0.1), 56 | ) 57 | ) 58 | 59 | state = optimizer.init(model_params) # pytype:disable=wrong-arg-types 60 | rand_key = jax.random.key(42) 61 | updates = jax.tree.map( 62 | lambda x: jax.random.uniform(rand_key, x.shape), model_params 63 | ) 64 | transformed_updates, _ = optimizer.update(updates, state, model_params) # pytype:disable=wrong-arg-types 65 | new_model_params: ModelParams = embed_optimizer.apply_updates_for_sc_model( 66 | model_params, transformed_updates 67 | ) 68 | expected_updated_params = ModelParams( 69 | params={ 70 | "layers_0": updates.params[ 71 | "layers_0" 72 | ], # SC apply updates just return updates 73 | "layers_2": jax.tree.map( 74 | operator.add, 75 | model_params.params["layers_2"], 76 | transformed_updates.params["layers_2"], 77 | ), 78 | } 79 | ) 80 | jax.tree.map( 81 | lambda x, y: self.assertTrue( 82 | jnp.allclose(x, y), f"Arrays {x} and {y} are not equal." 83 | ), 84 | new_model_params, 85 | expected_updated_params, 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/fdo/fdo_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Abstract interface for FDO client.""" 15 | 16 | import abc 17 | 18 | from jax_tpu_embedding.sparsecore.lib.nn import embedding 19 | 20 | 21 | class FDOClient(abc.ABC): 22 | """Abstract interface for FDO client. 23 | 24 | This class defines the interface for a per-process client that interacts with 25 | the FDO system. An implementation of this class should define how the FDO 26 | stats are recorded and published to the storage location (disk, database, 27 | etc.).
The load method should return the current aggregated (across all 28 | processes) stats from the storage location. 29 | 30 | Typical usage: 31 | 1. Create an instance of an implementation of FDOClient. 32 | 2. Call `record` to record the raw stats to the process-local FDO client. 33 | 3. (Optional) Repeat a few steps of training. 34 | 4. Call `publish` on the client instance to publish the stats to the 35 | storage location. 36 | 5. Call `load` on the client instance to get the aggregated (across all 37 | processes) stats from the storage location. 38 | """ 39 | 40 | @abc.abstractmethod 41 | def record( 42 | self, 43 | data: embedding.SparseDenseMatmulInputStats, 44 | ) -> None: 45 | """Records the raw stats to local memory. 46 | 47 | An implementation of this method defines how the raw stats are processed and 48 | stored in preparation for publishing to the storage location. 49 | 50 | Args: 51 | data: Mapping of data stats to be recorded. 52 | """ 53 | raise NotImplementedError 54 | 55 | @abc.abstractmethod 56 | def publish(self) -> None: 57 | """Publishes stats to the storage location. 58 | 59 | An implementation of this method defines how the stats are published to the 60 | storage location. For instance, this could involve writing the stats to a 61 | file or updating a database. 62 | """ 63 | raise NotImplementedError 64 | 65 | @abc.abstractmethod 66 | def load( 67 | self, 68 | ) -> embedding.SparseDenseMatmulInputStats: 69 | """Loads the state of the local FDO client and returns the aggregated stats. 70 | 71 | An implementation of this method defines how the stats are aggregated across 72 | all processes. For instance, this could involve reading the stats from all 73 | files written by `publish` or a database and then aggregating them. 74 | 75 | Returns: 76 | The aggregated `SparseDenseMatmulInputStats` across all processes, where 77 | max_ids is a mapping of table name to max ids per partition, max_uniques 78 | is a mapping of table name to max unique ids per partition, and 79 | required_buffer_sizes is a mapping of table name to required buffer size 80 | per sparsecore. 81 | """ 82 | raise NotImplementedError 83 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/flax/linen/embed_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
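A hedged sketch of the `record`/`publish`/`load` cycle described in the `FDOClient` docstring above. The concrete client, `train_step`, and `batches` are hypothetical stand-ins; only the call pattern comes from the interface:

```python
def train_with_fdo(fdo_client, train_step, batches, publish_every=100):
  """Runs training steps while feeding observed stats through an FDO client."""
  for step, batch in enumerate(batches):
    stats = train_step(batch)  # assumed to return SparseDenseMatmulInputStats
    fdo_client.record(stats)   # accumulate raw stats in process-local memory
    if (step + 1) % publish_every == 0:
      fdo_client.publish()            # write this process's stats out
      aggregated = fdo_client.load()  # read stats aggregated across processes
      # `aggregated` could now drive re-tuning of limits such as
      # max_ids_per_partition; that feedback loop is left out here.
      del aggregated
```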
14 | """Optimizer for models with SparseCore modules.""" 15 | 16 | from typing import Any 17 | 18 | import jax 19 | from jax import numpy as jnp 20 | from jax_tpu_embedding.sparsecore.lib.flax.linen import embed 21 | import optax 22 | 23 | 24 | def _is_emb_path(path: list[Any]) -> bool: 25 | return any( 26 | isinstance(level, jax.tree_util.DictKey) 27 | and level.key == embed.EMBEDDING_PARAM_NAME 28 | for level in path 29 | ) 30 | 31 | 32 | def create_optimizer_for_sc_model( 33 | params: Any, tc_optimizer: optax.GradientTransformation 34 | ) -> optax.GradientTransformation: 35 | """Create the optimizer for the model. 36 | 37 | Args: 38 | params: A PyTree of model parameters. 39 | tc_optimizer: The optimizer for the TensorCore part of the model. 40 | 41 | Returns: 42 | An optax.GradientTransformation that applies updates to the model. 43 | """ 44 | embedding_params_tree = jax.tree_util.tree_map_with_path( 45 | lambda path, v: ( 46 | 'tc_optimizer' if not _is_emb_path(path) else 'sc_optimizer' 47 | ), 48 | params, 49 | ) 50 | 51 | # Create optimizer for the model. 52 | return optax.multi_transform( 53 | { 54 | 'tc_optimizer': tc_optimizer, 55 | 'sc_optimizer': _get_optimizer_for_optax(), 56 | }, 57 | embedding_params_tree, 58 | ) 59 | 60 | 61 | def apply_updates_for_sc_model(params, updates): 62 | """Apply the updates to the params for models with SparseCore modules.""" 63 | 64 | def apply_update_to_params(path, param, update): 65 | if not _is_emb_path(path): 66 | return jnp.asarray(param + update).astype(jnp.asarray(update).dtype) 67 | else: 68 | return _apply_update(param, update) 69 | 70 | return jax.tree_util.tree_map_with_path( 71 | apply_update_to_params, 72 | params, 73 | updates, 74 | ) 75 | 76 | 77 | def _get_optimizer_for_optax() -> optax.GradientTransformation: 78 | # For now, the optimizer is part of the SC grad op. 79 | # We create a trivial optimizer to simply return the new embedding table. 80 | # 81 | # For the long run, we'd like to have SC grad op to return the real gradients, 82 | # and this function would need to create the real optimizer for SC. 83 | return optax.GradientTransformation( 84 | init=lambda params: optax.EmptyState(), 85 | update=lambda grads, state, params: (grads, state), 86 | ) 87 | 88 | 89 | def _apply_update(params, updates): 90 | # For now, since the grad op and the SC dummy optimizer are 91 | # returning the updated embedding table as the "update", we just need to 92 | # use the updated embedding table here. 93 | # 94 | # For the long run, we'd like to implement logic to apply the real 95 | # embedding table updates to embedding tables. 96 | del params 97 | return updates 98 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/nn/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS") 15 | load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") 16 | load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library") 17 | 18 | package( 19 | default_applicable_licenses = ["//:license"], 20 | default_visibility = EXTERNAL_USERS, 21 | ) 22 | 23 | pytype_strict_library( 24 | name = "embedding_spec", 25 | srcs = [ 26 | "embedding_spec.py", 27 | ], 28 | deps = [ 29 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adagrad", 30 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adagrad_momentum", 31 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adam", 32 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_ftrl", 33 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_laprop", 34 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_sgd", 35 | pypi_requirement("flax:core"), 36 | pypi_requirement("jax"), 37 | pypi_requirement("jax/extend"), 38 | ], 39 | ) 40 | 41 | pytype_strict_library( 42 | name = "embedding", 43 | srcs = ["embedding.py"], 44 | deps = [ 45 | ":embedding_spec", 46 | ":table_stacking", 47 | "//jax_tpu_embedding/sparsecore/lib/core:pybind_input_preprocessing", 48 | "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr", 49 | "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2", 50 | "//jax_tpu_embedding/sparsecore/utils", 51 | pypi_requirement("absl/logging"), 52 | pypi_requirement("einops"), 53 | pypi_requirement("flax:core"), 54 | pypi_requirement("jax"), 55 | pypi_requirement("numpy"), 56 | ], 57 | ) 58 | 59 | pytype_strict_library( 60 | name = "table_stacking", 61 | srcs = ["table_stacking.py"], 62 | deps = [ 63 | ":embedding_spec", 64 | "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2", 65 | pypi_requirement("absl/logging"), 66 | pypi_requirement("jax"), 67 | pypi_requirement("numpy"), 68 | ], 69 | ) 70 | 71 | pytype_strict_library( 72 | name = "embedding_pipelining_utils", 73 | srcs = [ 74 | "embedding_pipelining_utils.py", 75 | ], 76 | deps = [ 77 | ":embedding", 78 | ":embedding_spec", 79 | pypi_requirement("flax:core"), 80 | pypi_requirement("jax"), 81 | ], 82 | ) 83 | 84 | pytype_strict_library( 85 | name = "nn", 86 | srcs = ["__init__.py"], 87 | visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"], 88 | deps = [ 89 | ":embedding", # buildcleaner: keep 90 | ":embedding_spec", # buildcleaner: keep 91 | ":table_stacking", # buildcleaner: keep 92 | ], 93 | ) 94 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce_interface.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_INTERFACE_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_INTERFACE_H_ 16 | 17 | #include <cstdint> 18 | #include <memory> 19 | #include <string> 20 | #include <vector> 21 | 22 | #include "absl/container/flat_hash_map.h" // from @com_google_absl 23 | #include "absl/log/check.h" // from @com_google_absl 24 | #include "absl/log/log.h" // from @com_google_absl 25 | #include "absl/status/statusor.h" // from @com_google_absl 26 | #include "absl/strings/str_join.h" // from @com_google_absl 27 | #include "jax_tpu_embedding/sparsecore/lib/core/all_reduce_interface.h" 28 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce.grpc.pb.h" 29 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce_service_impl.h" 30 | 31 | namespace jax_sc_embedding { 32 | namespace rpc { 33 | 34 | class GrpcAllReduceInterface final : public AllReduceInterface { 35 | public: 36 | GrpcAllReduceInterface(std::vector<std::string> peer_addresses, int task_id, 37 | int num_tasks, AllReduceServiceImpl* local_service, 38 | int threads_per_task = 1) 39 | : peer_addresses_(peer_addresses), 40 | task_id_(task_id), 41 | num_tasks_(num_tasks), 42 | threads_per_task_(threads_per_task), 43 | local_service_(local_service) { 44 | VLOG(2) << "GrpcAllReduceInterface created with task_id: " << task_id 45 | << " num_tasks: " << num_tasks 46 | << " peer_addresses: " << absl::StrJoin(peer_addresses, ","); 47 | CHECK_EQ(peer_addresses_.size(), num_tasks_ - 1); 48 | SetUp(); 49 | } 50 | 51 | // Performs a blocking All-Reduce operation. 52 | // `sync_key`: A unique key for this all-reduce operation. 53 | // `minibatching_required`: The local value to be reduced. 54 | absl::StatusOr<bool> BlockingAllReduce(int sync_key, 55 | bool minibatching_required) override; 56 | 57 | // Performs a blocking All-Reduce operation. 58 | // `sync_key`: A unique key for this all-reduce operation. 59 | // `minibatching_split`: The local value to be reduced. 60 | absl::StatusOr<uint64_t> BlockingAllReduce( 61 | int sync_key, uint64_t minibatching_split) override; 62 | 63 | private: 64 | // Internal helper to perform the gRPC-based blocking All-Reduce. 65 | // `request`: Contains the sync_key, src_rank, and the value to be reduced. 66 | // Returns the reduced value. 67 | absl::StatusOr<AllReduceData> BlockingAllReduce(const AllReduceData& request); 68 | 69 | std::vector<std::string> peer_addresses_; 70 | int task_id_; 71 | int num_tasks_; 72 | int threads_per_task_; 73 | AllReduceServiceImpl* local_service_; // Not owned. 74 | 75 | absl::flat_hash_map<std::string, std::unique_ptr<AllReduceGrpcService::Stub>> 76 | stubs_; 77 | 78 | private: 79 | // Create gRPC channels to other peers. 80 | void SetUp(); 81 | }; 82 | 83 | } // namespace rpc 84 | } // namespace jax_sc_embedding 85 | 86 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_INTERFACE_H_ 87 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_ 16 | 17 | #include <cstdint> 18 | #include <optional> 19 | 20 | #include "absl/log/check.h" // from @com_google_absl 21 | #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h" 22 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" 23 | #include "pybind11/cast.h" // from @pybind11 24 | #include "pybind11/gil.h" // from @pybind11 25 | #include "pybind11/numpy.h" // from @pybind11 26 | #include "pybind11/pybind11.h" // from @pybind11 27 | 28 | namespace jax_sc_embedding { 29 | 30 | namespace py = ::pybind11; 31 | 32 | class NumpySparseInputBatch : public AbstractInputBatch { 33 | public: 34 | NumpySparseInputBatch(const py::array& feature) 35 | : NumpySparseInputBatch(feature, std::nullopt) {} 36 | 37 | NumpySparseInputBatch(const py::array& feature, 38 | std::optional<py::array> weights) 39 | : feature_(feature), weights_(weights) { 40 | DCHECK(PyGILState_Check()) 41 | << "Need GIL to create references to features and weights."; 42 | if (weights_.has_value()) { 43 | CHECK_EQ(feature_.shape(0), weights_->shape(0)) 44 | << "Batch size mismatch for features and weights."; 45 | CHECK_EQ(feature_.ndim(), weights_->ndim()) 46 | << "Dimension mismatch for features and weights"; 47 | } 48 | CHECK(feature_.ndim() == 1 || feature_.ndim() == 2) 49 | << "Only 1D and 2D numpy arrays supported as inputs."; 50 | 51 | if (feature_.ndim() == 1) { 52 | // Iterating over every row to sum up the number of IDs negates the 53 | // performance benefit of reserving memory for them, so we underestimate 54 | // the number of IDs as 1 per sample. 55 | id_count_ = feature_.shape(0); 56 | } else { 57 | id_count_ = feature_.shape(0) * feature_.shape(1); 58 | } 59 | } 60 | 61 | // Returns the number of samples in this input batch. 62 | int64_t size() const override { return feature_.shape(0); } 63 | 64 | // Returns the total number of embedding IDs across all samples. 65 | int64_t id_count() const override { return id_count_; } 66 | 67 | std::optional<int64_t> GetIdsCountInSlice(int start_row, 68 | int end_row) const override { 69 | if (feature_.ndim() == 2) { 70 | return (end_row - start_row) * feature_.shape(1); 71 | } 72 | return std::nullopt; 73 | } 74 | 75 | bool HasVariableWeights() const override { return weights_.has_value(); } 76 | 77 | void ExtractCooTensors(const ExtractCooTensorsOptions& options, 78 | ExtractedCooTensors& coo_tensors) override; 79 | 80 | private: 81 | const py::array feature_; 82 | // NOTE: stored as optional<py::array> instead of py::object to avoid GIL 83 | // locking to cast to an array every time.
84 | const std::optional<py::array> weights_; 85 | int64_t id_count_; 86 | }; 87 | 88 | } // namespace jax_sc_embedding 89 | 90 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_ 91 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/examples/models/shakespeare/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Shakespeare next word predictor model.""" 15 | 16 | from typing import Any, Mapping 17 | 18 | from flax import linen as nn 19 | import jax 20 | import jax.numpy as jnp 21 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec 22 | import optax 23 | 24 | 25 | class Model(nn.Module): 26 | """A simple model that predicts the next word in a sequence of words. 27 | 28 | Attributes: 29 | global_batch_size: The number of examples in the global batch. 30 | vocab_size: The number of unique words in the vocabulary. 31 | seq_len: The length of the sequences in the global batch. 32 | embedding_size: The dimension of the embedding vectors. 33 | table_name: The name of the embedding table. 34 | feature_name: The name of the embedding feature. 35 | """ 36 | 37 | global_batch_size: int 38 | vocab_size: int 39 | seq_len: int 40 | embedding_size: int 41 | table_name: str = 'shakespeare_table' 42 | feature_name: str = 'shakespeare_feature' 43 | 44 | def create_feature_specs( 45 | self, 46 | ) -> Mapping[str, embedding_spec.FeatureSpec]: 47 | """Creates the feature specs for the Shakespeare model. 48 | 49 | Returns: 50 | The feature specs for the Shakespeare model. 51 | """ 52 | table_spec = embedding_spec.TableSpec( 53 | vocabulary_size=self.vocab_size, 54 | embedding_dim=self.embedding_size, 55 | initializer=jax.nn.initializers.zeros, 56 | optimizer=embedding_spec.SGDOptimizerSpec(), 57 | combiner='sum', 58 | name=self.table_name, 59 | max_ids_per_partition=64, 60 | max_unique_ids_per_partition=64, 61 | ) 62 | feature_spec = embedding_spec.FeatureSpec( 63 | table_spec=table_spec, 64 | input_shape=(self.global_batch_size * self.seq_len, 1), 65 | output_shape=( 66 | self.global_batch_size * self.seq_len, 67 | self.embedding_size, 68 | ), 69 | name=self.feature_name, 70 | ) 71 | feature_specs = nn.FrozenDict({self.feature_name: feature_spec}) 72 | return feature_specs 73 | 74 | @nn.compact 75 | def __call__(self, emb_activations: Mapping[str, jax.Array]): 76 | # Unpack the activations. 77 | x = emb_activations[self.feature_name] 78 | x = jnp.reshape(x, (x.shape[0], -1)) 79 | # Apply the model. 80 | x = nn.Dense(self.embedding_size)(x) 81 | x = nn.Dense(self.vocab_size)(x) 82 | return x 83 | 84 | 85 | def loss( 86 | model: nn.Module, 87 | params: Any, 88 | emb_activations: Mapping[str, jax.Array], 89 | labels: jax.Array, 90 | ) -> tuple[jax.Array, jax.Array]: 91 | """Applies the embedding activations to the model and returns the loss.
92 | 93 | Args: 94 | model: The model being trained. 95 | params: The parameters of the model. 96 | emb_activations: The embedding activations that will be applied. 97 | labels: The integer labels corresponding to the embedding activations. 98 | 99 | Returns: 100 | A tuple of (mean cross-entropy loss, logits). 101 | """ 102 | logits = model.apply(params, emb_activations) 103 | xentropy = optax.softmax_cross_entropy_with_integer_labels( 104 | logits=logits, labels=labels 105 | ) 106 | return jnp.mean(xentropy), logits 107 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/primitives/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for sparsecore grad primitives.""" 15 | 16 | from typing import Any 17 | 18 | from jax import typing as jax_typing 19 | import numpy as np 20 | 21 | 22 | def ensure_dtype(check: Any, expected_type: Any, object_name: str): 23 | if check.dtype != expected_type: 24 | raise ValueError( 25 | f"{object_name} must have type {expected_type!r}, got {check.dtype!r}" 26 | ) 27 | 28 | 29 | def ensure_dim(check: Any, expected_dim: int, object_name: str): 30 | if len(check.shape) != expected_dim: 31 | raise ValueError( 32 | f"{object_name} must have dim {expected_dim!r}, got {check.shape!r}" 33 | ) 34 | 35 | 36 | def validate_abstract_eval_params( 37 | lhs_row_pointers: jax_typing.ArrayLike, 38 | lhs_local_embedding_ids: jax_typing.ArrayLike, 39 | lhs_local_sample_ids: jax_typing.ArrayLike, 40 | lhs_gains: jax_typing.ArrayLike, 41 | num_minibatches_per_physical_sparse_core: np.int32, 42 | embedding_table: jax_typing.ArrayLike, 43 | activations_grad: jax_typing.ArrayLike, 44 | max_ids_per_partition: int, 45 | max_unique_ids_per_partition: int, 46 | computation_name: str, 47 | sharding_strategy: int, 48 | ): 49 | """Validate parameters common to all sparsecore grad primitives.""" 50 | ensure_dtype(lhs_row_pointers, np.int32, "lhs_row_pointers") 51 | ensure_dtype(lhs_local_sample_ids, np.int32, "lhs_local_sample_ids") 52 | ensure_dtype(lhs_local_embedding_ids, np.int32, "lhs_local_embedding_ids") 53 | ensure_dtype(lhs_gains, np.float32, "lhs_gains") 54 | ensure_dtype( 55 | num_minibatches_per_physical_sparse_core, 56 | np.int32, 57 | "num_minibatches_per_physical_sparse_core", 58 | ) 59 | ensure_dtype(embedding_table, np.float32, "embedding_table") 60 | ensure_dtype(activations_grad, np.float32, "activations_grad") 61 | ensure_dim(lhs_row_pointers, 1, "lhs_row_pointers") 62 | ensure_dim(embedding_table, 2, "embedding_table") 63 | ensure_dim( 64 | num_minibatches_per_physical_sparse_core, 65 | 0, 66 | "num_minibatches_per_physical_sparse_core", 67 | ) 68 | ensure_dim(activations_grad, 2, "activations_grad") 69 | if ( 70 | lhs_local_sample_ids.shape != lhs_local_embedding_ids.shape 71 | or lhs_gains.shape != lhs_local_embedding_ids.shape 72 | or
len(lhs_local_sample_ids.shape) != 1 73 | ): 74 | raise ValueError( 75 | "LHS sample IDs, embedding IDs, and gains must all have " 76 | f"equal rank 1 shapes, got shapes {lhs_local_sample_ids.shape}, " 77 | f"{lhs_local_embedding_ids.shape} and {lhs_gains.shape}" 78 | ) 79 | if embedding_table.shape[-1] != activations_grad.shape[-1]: 80 | raise ValueError( 81 | "embedding_table and activations_grad must have equal feature (minor)" 82 | f" dimensions, got {embedding_table.shape}, {activations_grad.shape}" 83 | ) 84 | 85 | if sharding_strategy != 1: 86 | raise ValueError( 87 | f"sharding_strategy must be MOD (1), got {sharding_strategy}" 88 | ) 89 | 90 | if max_ids_per_partition <= 0: 91 | raise ValueError( 92 | f"max_ids_per_partition must be positive, got {max_ids_per_partition}" 93 | ) 94 | 95 | if max_unique_ids_per_partition <= 0: 96 | raise ValueError( 97 | "max_unique_ids_per_partition must be positive, got" 98 | f" {max_unique_ids_per_partition}" 99 | ) 100 | if not computation_name: 101 | raise ValueError("computation_name must be non-empty") 102 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_COO_INPUT_BATCH_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_COO_INPUT_BATCH_H_ 16 | 17 | #include <Python.h> 18 | 19 | #include <cstdint> 20 | #include <optional> 21 | #include <string> 22 | #include <utility> 23 | #include <vector> 24 | 25 | #include "absl/base/call_once.h" // from @com_google_absl 26 | #include "absl/log/check.h" // from @com_google_absl 27 | #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h" 28 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" 29 | #include "pybind11/cast.h" // from @pybind11 30 | #include "pybind11/gil.h" // from @pybind11 31 | #include "pybind11/numpy.h" // from @pybind11 32 | #include "pybind11/pybind11.h" // from @pybind11 33 | #include "pybind11/pytypes.h" // from @pybind11 34 | 35 | namespace jax_sc_embedding { 36 | 37 | namespace py = ::pybind11; 38 | 39 | // This class represents a sparse input batch in COO format. 40 | // - `indices` is a 2D array where each row represents a (row_id, col_id) pair. 41 | // It is assumed that `indices` is sorted in row-major order. 42 | // - `values` is a 1D array where each element represents the value associated 43 | // with the corresponding (row_id, col_id) pair in `indices`.
44 | class PySparseCooInputBatch : public AbstractInputBatch { 45 | public: 46 | PySparseCooInputBatch(const py::array_t<int64_t>& indices, 47 | const py::array_t<int64_t>& values, 48 | const py::array_t<int64_t>& dense_shape, 49 | const int64_t max_vocab_id, 50 | std::string table_name) 51 | : indices_(indices), 52 | values_(values), 53 | max_vocab_id_(max_vocab_id), 54 | batch_size_(dense_shape.at(0)), 55 | table_name_(std::move(table_name)) { 56 | DCHECK(PyGILState_Check()) 57 | << "Need GIL to create references to indices and values."; 58 | } 59 | 60 | // Returns the number of samples in this input batch. 61 | int64_t size() const override { return batch_size_; } 62 | 63 | // Returns the total number of embedding IDs across all samples. 64 | int64_t id_count() const override { return values_.size(); } 65 | 66 | std::optional<int64_t> GetIdsCountInSlice(int start_row, 67 | int end_row) const override { 68 | ConstructRowPointersIfRequired(); 69 | return row_pointers_[end_row] - row_pointers_[start_row]; 70 | } 71 | 72 | bool HasVariableWeights() const override { return false; } 73 | 74 | // Extracts COO tensors for each SparseCore. 75 | void ExtractCooTensors(const ExtractCooTensorsOptions& options, 76 | ExtractedCooTensors& coo_tensors) override; 77 | 78 | private: 79 | // (N,2) array, sorted by row_id. 80 | const py::array_t<int64_t> indices_; 81 | const py::array_t<int64_t> values_; 82 | const int64_t max_vocab_id_; 83 | const int64_t batch_size_; 84 | const std::string table_name_; 85 | 86 | mutable std::vector<int> row_pointers_; 87 | mutable absl::once_flag row_pointer_construction_flag_; 88 | 89 | // Converts this to a CSR format. A refactor could return an object of type 90 | // SparseCsrInputBatch after Slicing, and ExtractCooTensors can call 91 | // the same function on a temporary object of SparseCsrInputBatch type. 92 | void ConstructRowPointersIfRequired() const; 93 | 94 | // Internal function called by `ConstructRowPointersIfRequired`. 95 | void ConstructRowPointers() const; 96 | }; 97 | } // namespace jax_sc_embedding 98 | 99 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_COO_INPUT_BATCH_H_ 100 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_nnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Shakespeare model using embedding layer.""" 15 | 16 | from flax import nnx 17 | import jax 18 | import jax.numpy as jnp 19 | from jax_tpu_embedding.sparsecore.lib.flax.nnx import embed 20 | from jax_tpu_embedding.sparsecore.lib.nn import embedding 21 | from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec 22 | 23 | Nested = embedding.Nested 24 | 25 | 26 | ################################################################################ 27 | # Define the model.
28 | ################################################################################ 29 | class Model(nnx.Module): 30 | """Shakespeare model using embedding layer.""" 31 | 32 | def __init__( 33 | self, 34 | *, 35 | feature_specs: Nested[embedding_spec.FeatureSpec], 36 | global_batch_size: int, 37 | vocab_size: int, 38 | seq_len: int, 39 | embedding_size: int, 40 | enable_minibatching: bool, 41 | mesh: jax.sharding.Mesh, 42 | sharding_axis: str, 43 | ): 44 | self.feature_name = 'shakespeare_feature' 45 | assert len(feature_specs) == 1, 'Shakespeare model expects one feature.' 46 | assert self.feature_name in feature_specs, ( 47 | 'Shakespeare model expects feature named "%s".' % self.feature_name 48 | ) 49 | 50 | self.feature_specs = feature_specs 51 | self.global_batch_size = global_batch_size 52 | self.vocab_size = vocab_size 53 | self.seq_len = seq_len 54 | self.embedding_size = embedding_size 55 | self.enable_minibatching = enable_minibatching 56 | self.mesh = mesh 57 | self.sharding_axis = sharding_axis 58 | rngs = nnx.Rngs(params=42) 59 | self.embedding_layer = embed.SparseCoreEmbed( 60 | feature_specs=self.feature_specs, 61 | mesh=self.mesh, 62 | sharding_axis=self.sharding_axis, 63 | rngs=rngs, 64 | enable_minibatching=enable_minibatching, 65 | ) 66 | e = self.embedding_size 67 | v = self.vocab_size 68 | s = self.seq_len 69 | self.dense_layer_1 = nnx.Linear( 70 | in_features=s * e, 71 | out_features=e, 72 | rngs=rngs, 73 | ) 74 | self.dense_layer_2 = nnx.Linear( 75 | in_features=e, 76 | out_features=v, 77 | rngs=rngs, 78 | ) 79 | 80 | def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]): 81 | # Add a sharding constraint to the array. 82 | # 83 | # Add a sharding constraint to the array to ensure that the sharding 84 | # information is not lost during compilation. This may not be necessary but 85 | # it helps SPMD and ensures that the sharding information is as expected. 86 | # 87 | # Args: 88 | # x: The array to add the sharding constraint to. 89 | # names: The mesh axes for the partition spec. 90 | # 91 | # Returns: 92 | # The array with the sharding constraint added. 93 | return jax.lax.with_sharding_constraint( 94 | x, 95 | jax.sharding.NamedSharding( 96 | self.mesh, jax.sharding.PartitionSpec(*names) 97 | ), 98 | ) 99 | 100 | def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput): 101 | # Run the embedding layer. 102 | x = self.embedding_layer(embedding_lookup_inputs) 103 | 104 | # Unpack the activations. 105 | x = x[self.feature_name] 106 | x = jnp.reshape(x, (self.global_batch_size, -1)) 107 | x = self.add_sharding_constraint(x, (self.sharding_axis,)) 108 | 109 | # Apply the dense portion of the model. 110 | x = self.dense_layer_1(x) 111 | x = self.add_sharding_constraint(x, (self.sharding_axis,)) 112 | x = self.dense_layer_2(x) 113 | x = self.add_sharding_constraint(x, (self.sharding_axis,)) 114 | 115 | return x 116 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #include "jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h" 15 | 16 | #include <Python.h> 17 | 18 | #include <cstdint> 19 | #include <vector> 20 | 21 | #include "absl/base/call_once.h" // from @com_google_absl 22 | #include "absl/log/check.h" // from @com_google_absl 23 | #include "absl/types/span.h" // from @com_google_absl 24 | #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" 25 | #include "jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h" 26 | #include "jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h" 27 | #include "jax_tpu_embedding/sparsecore/lib/core/unity_weights_stream_impl.h" 28 | #include "pybind11/gil.h" // from @pybind11 29 | #include "pybind11/pybind11.h" // from @pybind11 30 | #include "tsl/profiler/lib/traceme.h" 31 | 32 | namespace jax_sc_embedding { 33 | 34 | void PySparseCooInputBatch::ConstructRowPointers() const { 35 | if (!row_pointers_.empty()) { 36 | return; 37 | } 38 | auto indices_array = indices_.unchecked<2>(); 39 | auto values_array = values_.unchecked<1>(); 40 | // Precompute indexes for row starts. Add a sentinel node for last row. 41 | row_pointers_.reserve(batch_size_ + 1); 42 | int row_pointers_index = 0; 43 | int last_row_id = -1; // Only for DCHECK. 44 | int last_col_id = -1; // Only for DCHECK. 45 | int last_val = -1; // Only for DCHECK. 46 | for (int i = 0; i < indices_array.shape(0); ++i) { 47 | const int row_id = indices_array(i, 0), col_id = indices_array(i, 1), 48 | val = values_array(i); 49 | DCHECK_GE(row_id, last_row_id) << "Decreasing row id values for row-major."; 50 | while (row_pointers_index <= row_id) { 51 | // Increment index until we reach the current row. Keep storing the row 52 | // pointers. 53 | row_pointers_.push_back(i); 54 | ++row_pointers_index; 55 | } 56 | 57 | // Loop Invariant: The index should point to one beyond the current row id. 58 | DCHECK_EQ(row_pointers_index, row_id + 1); 59 | 60 | if (row_id == last_row_id) { // Same Row should have increasing col values. 61 | DCHECK_GT(col_id, last_col_id) 62 | << "Non-increasing col id values for row-major."; 63 | } 64 | 65 | last_row_id = row_id; // NOMUTANTS - debugging. 66 | last_col_id = col_id; // NOMUTANTS - debugging. 67 | last_val = val; // NOMUTANTS - debugging. 68 | } 69 | while (row_pointers_index <= batch_size_) { 70 | row_pointers_.push_back(indices_array.shape(0)); 71 | row_pointers_index++; 72 | } 73 | 74 | DCHECK_EQ(row_pointers_.size(), batch_size_ + 1); 75 | } 76 | 77 | void PySparseCooInputBatch::ConstructRowPointersIfRequired() const { 78 | absl::call_once(row_pointer_construction_flag_, 79 | &PySparseCooInputBatch::ConstructRowPointers, this); 80 | } 81 | 82 | void PySparseCooInputBatch::ExtractCooTensors( 83 | const ExtractCooTensorsOptions& options, ExtractedCooTensors& coo_tensors) { 84 | DCHECK(!PyGILState_Check()); // Does not require external GIL.
85 | tsl::profiler::TraceMe t([] { return "ExtractCooTensors"; }); 86 | 87 | ConstructRowPointersIfRequired(); 88 | 89 | SparseCsrInputBatchStream<int64_t, 90 | py::detail::unchecked_reference<int64_t, 1>, 91 | absl::Span<const int>> 92 | values_stream(values_.unchecked<1>(), absl::MakeConstSpan(row_pointers_), 93 | options.slice_start, options.slice_end, table_name_, 94 | max_vocab_id_); 95 | UnityWeightsStream weights_stream(values_stream); 96 | 97 | ProcessCooTensors(options, values_stream, weights_stream, coo_tensors); 98 | } 99 | } // namespace jax_sc_embedding 100 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce_service_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_SERVICE_IMPL_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_SERVICE_IMPL_H_ 16 | 17 | #include <memory> 18 | #include <optional> 19 | 20 | #include "absl/base/thread_annotations.h" // from @com_google_absl 21 | #include "absl/container/flat_hash_map.h" // from @com_google_absl 22 | #include "absl/status/statusor.h" // from @com_google_absl 23 | #include "absl/synchronization/barrier.h" // from @com_google_absl 24 | #include "absl/synchronization/blocking_counter.h" // from @com_google_absl 25 | #include "absl/synchronization/mutex.h" // from @com_google_absl 26 | #include "include/grpcpp/server_context.h" // from @com_github_grpc_grpc 27 | #include "include/grpcpp/support/server_callback.h" // from @com_github_grpc_grpc 28 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce.grpc.pb.h" 29 | #include "jax_tpu_embedding/sparsecore/lib/core/grpc/all_reduce.pb.h" // from internal 30 | 31 | namespace jax_sc_embedding { 32 | namespace rpc { 33 | 34 | // Implementation of the gRPC AllReduce service. This class manages the state 35 | // for multiple concurrent all-reduce operations, identified by a `sync_key`. 36 | class AllReduceServiceImpl : public AllReduceGrpcService::CallbackService { 37 | struct AllReduceState { 38 | AllReduceData local_data; 39 | // Counter to wait for all other local threads to make their 40 | // contributions. 41 | std::unique_ptr<absl::BlockingCounter> local_contributions_counter; 42 | // Counter for all local threads to retrieve the results and delete 43 | // this state from the map when done. 44 | std::unique_ptr<absl::BlockingCounter> results_counter; 45 | // Barrier to synchronize all local threads before they can retrieve the 46 | // final result. 47 | std::unique_ptr<absl::Barrier> global_results_barrier; 48 | // Counter for the local thread that performs the RPCs to wait for all 49 | // other tasks.
50 | std::unique_ptr<absl::BlockingCounter> incoming_rpc_counter; 51 | }; 52 | 53 | public: 54 | explicit AllReduceServiceImpl(int task_id, int num_tasks, 55 | int threads_per_task = 1) 56 | : task_id_(task_id), 57 | num_tasks_(num_tasks), 58 | threads_per_task_(threads_per_task) {} 59 | 60 | // Called by remote peers. Returns this server's local value for the sync_key. 61 | ::grpc::ServerUnaryReactor* ContributeData( 62 | ::grpc::CallbackServerContext* context, const AllReduceData* request, 63 | AllReduceResponse* response) override; 64 | 65 | // Method to register the local data for a given sync_key. Called by the local 66 | // client. Returns the locally-reduced data for the initializer thread, or 67 | // nullopt for other threads. 68 | absl::StatusOr<std::optional<AllReduceData>> InitializeOrUpdateState( 69 | int sync_key, const AllReduceData& data); 70 | 71 | // Waits for incoming RPCs from all other tasks. Should be called from only 72 | // the initializer thread. 73 | void WaitIncomingRPCs(int sync_key); 74 | 75 | // A barrier for all local threads to wait on before retrieving the result. 76 | void WaitResults(int sync_key); 77 | 78 | // Gets locally and globally reduced result. 79 | absl::StatusOr<AllReduceData> GetResult(int sync_key); 80 | 81 | private: 82 | int task_id_; 83 | int num_tasks_; 84 | // Number of threads (within the same process) that will participate in the 85 | // all-reduce operation. 86 | int threads_per_task_; 87 | 88 | absl::Mutex mutex_; 89 | absl::flat_hash_map<int, AllReduceState> all_reduce_state_map_ 90 | ABSL_GUARDED_BY(mutex_); 91 | // CV to wait for state to be updated by all local threads. 92 | absl::CondVar local_reduced_cv_ ABSL_GUARDED_BY(mutex_); 93 | }; 94 | 95 | } // namespace rpc 96 | } // namespace jax_sc_embedding 97 | 98 | #endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_GRPC_ALL_REDUCE_SERVICE_IMPL_H_ 99 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/auto_pipelining/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for auto pipelining.""" 15 | 16 | from collections.abc import Iterable 17 | import itertools 18 | 19 | import jax 20 | import jax.extend as jex 21 | 22 | 23 | EMBEDDING_LOOKUP_PRIMITIVE_PREFIX = 'sparse_dense_matmul_csr' 24 | EMBEDDING_UPDATE_PRIMITIVE_PREFIX = 'sparse_dense_matmul_grad' 25 | CUSTOM_VJP_CALL_PRIMITIVE_NAME = 'custom_vjp_call_jaxpr' 26 | SHARD_MAP_PRIMITIVE_NAME = 'shard_map' 27 | 28 | # The number of data inputs of the embedding lookup shard_map. 29 | # They are row_pointers, embedding_ids, sample_ids, gains, num_minibatches. 30 | EMBEDDING_LOOKUP_DATA_LEN = 5 31 | # The number of data inputs of the embedding update shard_map. 32 | # They are: 33 | # gradients, row_pointers, embedding_ids, sample_ids, gains, num_minibatches.
34 | EMBEDDING_UPDATE_DATA_LEN = EMBEDDING_LOOKUP_DATA_LEN + 1 35 | 36 | 37 | def is_embedding_lookup(eqn: jex.core.JaxprEqn) -> bool: 38 | if eqn.primitive.name != SHARD_MAP_PRIMITIVE_NAME: 39 | return False 40 | jaxpr = eqn.params['jaxpr'] 41 | for sub_eqn in jaxpr.eqns: 42 | if sub_eqn.primitive.name.startswith(EMBEDDING_LOOKUP_PRIMITIVE_PREFIX): 43 | return True 44 | return False 45 | 46 | 47 | def is_embedding_update(eqn: jex.core.JaxprEqn) -> bool: 48 | if eqn.primitive.name != SHARD_MAP_PRIMITIVE_NAME: 49 | return False 50 | jaxpr = eqn.params['jaxpr'] 51 | for sub_eqn in jaxpr.eqns: 52 | if sub_eqn.primitive.name.startswith(EMBEDDING_UPDATE_PRIMITIVE_PREFIX): 53 | return True 54 | return False 55 | 56 | 57 | def get_embedding_lookup_eqn( 58 | eqns: list[jex.core.JaxprEqn], 59 | ) -> jex.core.JaxprEqn: 60 | for eqn in eqns: 61 | if eqn.primitive.name.startswith(EMBEDDING_LOOKUP_PRIMITIVE_PREFIX): 62 | return eqn 63 | assert False, 'No embedding lookup found in the given eqns.' 64 | 65 | 66 | def replace_embedding_lookup_eqn( 67 | eqns: list[jex.core.JaxprEqn], new_lookup_eqn: jex.core.JaxprEqn, 68 | ) -> list[jex.core.JaxprEqn]: 69 | result = [] 70 | for eqn in eqns: 71 | if eqn.primitive.name.startswith(EMBEDDING_LOOKUP_PRIMITIVE_PREFIX): 72 | result.append(new_lookup_eqn) 73 | else: 74 | result.append(eqn) 75 | return result 76 | 77 | 78 | def lookup_params( 79 | eqn: jex.core.JaxprEqn, 80 | ) -> tuple[list[jax.core.Atom], list[jax.core.Atom]]: 81 | return ( 82 | eqn.invars[:EMBEDDING_LOOKUP_DATA_LEN], 83 | eqn.invars[EMBEDDING_LOOKUP_DATA_LEN:], 84 | ) 85 | 86 | 87 | def update_params( 88 | eqn: jex.core.JaxprEqn, 89 | ) -> tuple[list[jax.core.Atom], list[jax.core.Atom]]: 90 | return ( 91 | eqn.invars[:EMBEDDING_UPDATE_DATA_LEN], 92 | eqn.invars[EMBEDDING_UPDATE_DATA_LEN:], 93 | ) 94 | 95 | 96 | def clone_vars(var_list: Iterable[jex.core.Var]) -> list[jex.core.Var]: 97 | return [jex.core.Var(var.aval) for var in var_list] 98 | 99 | 100 | def inline_jaxpr( 101 | jaxpr: jex.core.Jaxpr, 102 | invars: list[jex.core.Var], 103 | outvars: list[jex.core.Var], 104 | ) -> list[jex.core.JaxprEqn]: 105 | """Inlines a jaxpr with given invars and outvars.""" 106 | assert not set(jaxpr.invars).intersection( 107 | jaxpr.outvars 108 | ), 'Returning invars directly as outvars is not supported.' 109 | assert not jaxpr.constvars, 'Jaxpr with consts is not supported.' 110 | assert len(invars) == len(jaxpr.invars) 111 | assert len(outvars) == len(jaxpr.outvars) 112 | 113 | var_mapping = { 114 | var: val 115 | for var, val in itertools.chain( 116 | zip(jaxpr.invars, invars), zip(jaxpr.outvars, outvars) 117 | ) 118 | } 119 | 120 | def _translate_outvar(var: jex.core.Var) -> jex.core.Var: 121 | return var_mapping.setdefault(var, clone_vars([var])[0]) 122 | 123 | def _translate_invar(var: jex.core.Var) -> jex.core.Var: 124 | return var if isinstance(var, jex.core.Literal) else var_mapping[var] 125 | 126 | return [ 127 | eqn.replace( 128 | invars=[_translate_invar(var) for var in eqn.invars], 129 | outvars=[_translate_outvar(var) for var in eqn.outvars], 130 | ) 131 | for eqn in jaxpr.eqns 132 | ] 133 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #include <bitset>
15 | #include <cstdint>
16 | #include <tuple>
17 | #include <utility>
18 | #include <vector>
19 | 
20 | #include <gmock/gmock.h>
21 | #include <gtest/gtest.h>
22 | #include "absl/types/span.h"  // from @com_google_absl
23 | #include "jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h"
24 | 
25 | namespace jax_sc_embedding {
26 | namespace {
27 | 
28 | using ::testing::UnorderedElementsAreArray;
29 | using ::testing::Values;
30 | 
31 | class ComputeMinibatchingSplitTest
32 |     : public ::testing::TestWithParam<
33 |           std::tuple<std::vector<int32_t>, int32_t, std::bitset<7>>> {};
34 | 
35 | TEST_P(ComputeMinibatchingSplitTest, TestSplits) {
36 |   auto [unique_ids_per_bucket, max_ids_per_partition, expected_split] =
37 |       GetParam();
38 | 
39 |   EXPECT_EQ(internal::ComputeMinibatchingSplit<8>(
40 |                 absl::MakeSpan(unique_ids_per_bucket), max_ids_per_partition),
41 |             expected_split);
42 | }
43 | 
44 | INSTANTIATE_TEST_SUITE_P(
45 |     TestSplits, ComputeMinibatchingSplitTest,
46 |     Values(
47 |         // Full Merge
48 |         std::make_tuple(std::vector<int32_t>(8, 10), 100, std::bitset<7>(0)),
49 |         // No Merges
50 |         std::make_tuple(std::vector<int32_t>(8, 100), 150,
51 |                         std::bitset<7>(0b1111111)),
52 |         // Partial Merge
53 |         std::make_tuple(std::vector<int32_t>{10, 20, 30, 40, 50, 60, 70, 80},
54 |                         100, std::bitset<7>(0b1101100)),
55 |         // Partial Merge 2
56 |         std::make_tuple(std::vector<int32_t>{50, 50, 20, 80, 10, 10, 70, 10},
57 |                         100, std::bitset<7>(0b1010000)),
58 |         // Partial Merge 3
59 |         std::make_tuple(std::vector<int32_t>{90, 10, 90, 10, 90, 10, 90, 10},
60 |                         100, std::bitset<7>(0b1110000))));
61 | 
62 | class MergeBucketsTest : public ::testing::TestWithParam<
63 |                              std::tuple<std::vector<int32_t>, std::bitset<7>,
64 |                                         std::vector<std::pair<int, int>>>> {};
65 | 
66 | TEST_P(MergeBucketsTest, TestMergeBuckets) {
67 |   auto [unique_ids_per_bucket, split_pos, expected_merged_buckets] = GetParam();
68 |   std::vector<std::pair<int, int>> merged_buckets;
69 | 
70 |   auto merge_fn = [&](int left, int right) {
71 |     merged_buckets.push_back(std::make_pair(left, right));
72 |   };
73 |   internal::MergeBuckets<8>(split_pos, merge_fn);
74 | 
75 |   EXPECT_THAT(merged_buckets,
76 |               UnorderedElementsAreArray(expected_merged_buckets));
77 | }
78 | 
79 | INSTANTIATE_TEST_SUITE_P(
80 |     TestMergeBuckets, MergeBucketsTest,
81 |     Values(
82 |         // No Merge
83 |         std::make_tuple(std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8},
84 |                         std::bitset<7>(0b1111111),
85 |                         std::vector<std::pair<int, int>>{}),
86 |         // Partial Merge
87 |         std::make_tuple(std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8},
88 |                         std::bitset<7>(0b0101101),
89 |                         std::vector<std::pair<int, int>>{
90 |                             {2, 3}, {0, 2}, {0, 4}}),
91 |         // Full Merge
92 |         std::make_tuple(
93 |             std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8}, std::bitset<7>(0b0),
94 |             std::vector<std::pair<int, int>>{
95 |                 {0, 1}, {2, 3}, {4, 5}, {6, 7}, {0, 2}, {4, 6}, {0, 4}}),
96 |         // Partial Merge 2
97 |         std::make_tuple(std::vector<int32_t>{50, 50, 20, 80, 10, 10, 70, 10},
98 |                         std::bitset<7>(0b0001010),
99 |                         std::vector<std::pair<int, int>>{
100 |                             {0, 1}, {4, 5}, {0, 2}, {4, 6}, {0, 4}}),
101 |         // Partial Merge 3
102 |         std::make_tuple(std::vector<int32_t>{90, 10, 90, 10, 90, 10, 90, 10},
103 |                         std::bitset<7>(0b0101010),
104 |                         std::vector<std::pair<int, int>>{
105 |                             {0, 1}, {4, 5}, {0, 2}, {0, 4}})));
106 | 
107 | }  // namespace
108 | }  // namespace jax_sc_embedding
109 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The JAX SC Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_CSR_INPUT_STREAM_IMPL_H_
15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_CSR_INPUT_STREAM_IMPL_H_
16 | 
17 | #include <limits>
18 | #include <type_traits>
19 | 
20 | #include "absl/base/attributes.h"  // from @com_google_absl
21 | #include "absl/strings/string_view.h"  // from @com_google_absl
22 | 
23 | namespace jax_sc_embedding {
24 | 
25 | // Class to iterate over a sparse CSR array.
26 | // Example:
27 | //   values = [1, 2, 3, 4, 5, 6]
28 | //   row_pointers = [0, 2, 5, 6]
29 | // This represents a sparse matrix with 3 rows:
30 | //   Row 0: [1, 2]
31 | //   Row 1: [3, 4, 5]
32 | //   Row 2: [6]
33 | // ValuesView and RowPointersView are template parameters that represent a view
34 | // into the underlying data (or even the actual data itself):
35 | //   - ValuesView is required to support `operator[]`.
36 | //   - RowPointersView is required to support `operator[]`.
37 | // This allows the class to be used with different types of data sources, such
38 | // as vectors, arrays, or other data structures.
39 | template <typename T, typename ValuesView, typename RowPointersView>
40 | class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
41 |  public:
42 |   // Ensures that ValuesView and RowPointersView are view-like types that are
43 |   // cheap to copy. This prevents expensive copies of underlying data
44 |   // containers (e.g. std::vector) and encourages passing views (e.g.
45 |   // absl::Span) instead.
46 |   static_assert(std::is_trivially_copyable_v<ValuesView>,
47 |                 "ValuesView must be trivially copyable.");
48 |   static_assert(std::is_trivially_copyable_v<RowPointersView>,
49 |                 "RowPointersView must be trivially copyable.");
50 | 
51 |   SparseCsrInputBatchStream(
52 |       ValuesView values ABSL_ATTRIBUTE_LIFETIME_BOUND,
53 |       RowPointersView row_pointers ABSL_ATTRIBUTE_LIFETIME_BOUND, int row_start,
54 |       int row_end, absl::string_view table_name = "unknown_table_name",
55 |       T max_vocab_id = std::numeric_limits<T>::max())
56 |       : values_ref_(values),
57 |         row_pointers_(row_pointers),
58 |         row_start_(row_start),
59 |         curr_row_(row_start),
60 |         row_end_(row_end),
61 |         curr_idx_(row_pointers[row_start]),
62 |         curr_row_start_idx_(row_pointers[row_start]),
63 |         max_vocab_id_(max_vocab_id),
64 |         table_name_(table_name) {
65 |     curr_row_cols_ = curr_row_ == row_end_
66 |                          ? 0
67 |                          : row_pointers_[curr_row_ + 1] - curr_row_start_idx_;
68 |   }
69 | 
70 |   // Returns number of values in current row.
71 |   int cols() const { return curr_row_cols_; }
72 | 
73 |   void NextRow() {
74 |     ++curr_row_;
75 |     if (curr_row_ < row_end_) {
76 |       curr_row_start_idx_ = row_pointers_[curr_row_];
77 |       curr_idx_ = curr_row_start_idx_;
78 |       curr_row_cols_ = row_pointers_[curr_row_ + 1] - curr_row_start_idx_;
79 |     }
80 |   }
81 | 
82 |   void NextCol() { ++curr_idx_; }
83 | 
84 |   void SeekCol(int col) { curr_idx_ = curr_row_start_idx_ + col; }
85 | 
86 |   int row() const { return curr_row_; }
87 | 
88 |   int col() const { return curr_idx_ - curr_row_start_idx_; }
89 | 
90 |   T get() const {
91 |     DCHECK_LT(curr_idx_, row_pointers_[curr_row_ + 1]);
92 |     T embedding_id = values_ref_[curr_idx_];
93 |     DCHECK(embedding_id >= 0 && embedding_id <= max_vocab_id_)
94 |         << "Invalid vocabulary id: " << embedding_id << " for table "
95 |         << table_name_ << " with vocabulary size: " << max_vocab_id_;
96 |     return embedding_id;
97 |   }
98 | 
99 |  private:
100 |   ValuesView values_ref_;
101 |   RowPointersView row_pointers_;
102 |   int row_start_;
103 |   int curr_row_;
104 |   int row_end_;
105 |   int curr_idx_;
106 | 
107 |   // Cached values to avoid memory reads when processing a row.
108 |   int curr_row_start_idx_;
109 |   int curr_row_cols_;
110 | 
111 |   T max_vocab_id_;
112 |   absl::string_view table_name_;
113 | };
114 | 
115 | }  // namespace jax_sc_embedding
116 | 
117 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SPARSE_CSR_INPUT_STREAM_IMPL_H_
118 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The JAX SC Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | syntax = "proto3";
15 | 
16 | package third_party.py.jax_tpu_embedding.sparsecore_lib_proto;
17 | 
18 | option java_multiple_files = true;
19 | 
20 | message FeatureSpecProto {
21 |   // The name of the user defined feature.
22 |   string feature_name = 1;
23 |   // The shape of the input training batch.
24 |   repeated int64 input_shape = 2;
25 |   // The expected shape of the output embedding lookup.
26 |   repeated int64 output_shape = 3;
27 |   // When multiple features are stacked, the `row_offset` specifies the first
28 |   // row of this feature's lookup output.
29 |   int64 row_offset = 4;
30 |   // When multiple features are stacked, the `col_offset` specifies the first
31 |   // vocabulary row of this feature's table in the stacked table.
32 |   int64 col_offset = 5;
33 |   // The `col_shift` specifies how the embedding table shards are rotated on the
34 |   // device.
35 |   int64 col_shift = 6;
36 | }
37 | 
38 | message TableSpecProto {
39 |   // The name of the table.
40 |   string table_name = 1;
41 |   // The user defined vocab size of the table.
42 |   int64 vocab_size = 2;
43 |   // The user defined embedding dim of the table.
44 |   int64 embedding_dim = 3;
45 |   // The max number of ids per partition for the table. This is an input data
46 |   // dependent value and is required by the compiler to appropriately allocate
47 |   // memory. When table stacking is used, this value is overridden by the
48 |   // max_ids_per_partition of the stacked table.
49 |   int64 max_ids_per_partition = 5;
50 |   // The max number of unique ids per partition for the table. This is an input
51 |   // data dependent value and is required by the compiler to appropriately
52 |   // allocate memory. When table stacking is used, this value is overridden by
53 |   // the max_unique_ids_per_partition of the stacked table.
54 |   int64 max_unique_ids_per_partition = 6;
55 |   // The padded vocab size of the table. This is the vocab size rounded up to
56 |   // the next multiple of 8 times the number of sparsecores.
57 |   int64 padded_vocab_size = 7;
58 |   // The padded embedding dim of the table. This is the embedding dim rounded up
59 |   // to the next multiple of 8.
60 |   int64 padded_embedding_dim = 8;
61 |   // When table stacking is used, this is the row offset of the table in the
62 |   // stacked table shard. 0 otherwise.
63 |   int64 row_offset_in_shard = 9;
64 |   // When table stacking is used, this is the rotation of the table shard in the
65 |   // stacked table shard. 0 otherwise.
66 |   int64 shard_rotation = 10;
67 |   // The list of features that point to this table.
68 |   repeated FeatureSpecProto feature_specs = 11;
69 | }
70 | 
71 | message StackedTableSpecProto {
72 |   // The name of the stack. This is usually a concatenation of the table names
73 |   // that are stacked.
74 |   string stack_name = 1;
75 |   // The vocab size of the stack. This is the sum of the vocab sizes (padded)
76 |   // of all the tables in the stack.
77 |   int64 stack_vocab_size = 2;
78 |   // The embedding dim of the stack. This is the sum of the embedding dims
79 |   // (padded) of all the tables in the stack.
80 |   int64 stack_embedding_dim = 3;
81 |   // The total number of samples (batch size) for the stack. This is the sum of
82 |   // sample sizes (batch dimension) of all the features in the stack.
83 |   int64 total_sample_count = 6;
84 |   // The max number of ids per partition for the stack. This is an input data
85 |   // dependent value and is required by the compiler to appropriately allocate
86 |   // memory.
87 |   int64 max_ids_per_partition = 7;
88 |   // The max number of unique ids per partition for the stack. This is an input
89 |   // data dependent value and is required by the compiler to appropriately
90 |   // allocate memory.
91 |   int64 max_unique_ids_per_partition = 8;
92 |   // Total number of sparsecores used in the training setup.
93 |   int64 num_sparsecores = 9;
94 |   // Specs for the tables that are stacked.
95 |   repeated TableSpecProto table_specs = 10;
96 | }
97 | 
98 | message EmbeddingSpecProto {
99 |   // The list of stacked tables that represent the layout of embedding tables
100 |   // during training.
101 |   repeated StackedTableSpecProto stacked_table_specs = 1;
102 | }
103 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/primitives/local_sparse_dense_matmul.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Primitive local_sparse_dense_matmul."""
15 | 
16 | import functools
17 | import json
18 | 
19 | import jax
20 | from jax import core
21 | import jax.extend as jex
22 | from jax.extend.mlir import ir
23 | from jax.extend.mlir.dialects import stablehlo as hlo
24 | from jax.interpreters import mlir
25 | from jax.interpreters import xla
26 | import jax.numpy as jnp
27 | import numpy as np
28 | 
29 | # Define the local sparse dense matmul primitive.
30 | tpu_local_sparse_dense_matmul_primitive = jex.core.Primitive(
31 |     "local_sparse_dense_matmul"
32 | )
33 | 
34 | 
35 | # Define the impl function for the sparse dense matmul primitive.
36 | tpu_local_sparse_dense_matmul_primitive.def_impl(
37 |     functools.partial(
38 |         xla.apply_primitive, tpu_local_sparse_dense_matmul_primitive
39 |     )
40 | )
41 | 
42 | 
43 | # Define the abstract eval function for the sparse dense matmul primitive.
44 | def _tpu_local_sparse_dense_matmul_abstract_eval(
45 |     lhs_local_embedding_ids: jnp.ndarray,
46 |     lhs_local_sample_ids: jnp.ndarray,
47 |     lhs_gains: jnp.ndarray,
48 |     embedding_table: jnp.ndarray,
49 |     *_,
50 |     device_batch_size: int,
51 | ):
52 |   """Abstract eval for the local sparse dense matmul primitive."""
53 | 
54 |   if lhs_local_sample_ids.dtype != np.int32:
55 |     raise ValueError(
56 |         "lhs_local_sample_ids must have type int32, got"
57 |         f" {lhs_local_sample_ids.dtype}"
58 |     )
59 | 
60 |   if lhs_local_embedding_ids.dtype != np.int32:
61 |     raise ValueError(
62 |         "lhs_local_embedding_ids must have type int32, got"
63 |         f" {lhs_local_embedding_ids.dtype}"
64 |     )
65 | 
66 |   if lhs_gains.dtype != np.float32:
67 |     raise ValueError(f"lhs_gains must have type float32, got {lhs_gains.dtype}")
68 | 
69 |   if embedding_table.dtype != np.float32:
70 |     raise ValueError(
71 |         f"embedding_table must have type float32, got {embedding_table.dtype}"
72 |     )
73 | 
74 |   if (
75 |       lhs_local_sample_ids.shape != lhs_local_embedding_ids.shape
76 |       or lhs_gains.shape != lhs_local_embedding_ids.shape
77 |       or len(lhs_local_sample_ids.shape) != 1
78 |   ):
79 |     raise ValueError(
80 |         "LHS sample IDs, embedding IDs, and gains must all have "
81 |         f"equal rank 1 shapes, got shapes {lhs_local_sample_ids.shape}, "
82 |         f"{lhs_local_embedding_ids.shape} and {lhs_gains.shape}"
83 |     )
84 | 
85 |   if len(embedding_table.shape) != 2:
86 |     raise ValueError(
87 |         f"embedding_table must have rank 2, got {embedding_table.shape}"
88 |     )
89 | 
90 |   return core.ShapedArray(
91 |       (device_batch_size, embedding_table.shape[1]),
92 |       dtype=jnp.float32,
93 |   )
94 | 
95 | 
96 | tpu_local_sparse_dense_matmul_primitive.def_abstract_eval(
97 |     _tpu_local_sparse_dense_matmul_abstract_eval
98 | )
99 | 
100 | 
101 | # Define the mlir lowering rule for the local sparse dense matmul primitive.
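# As an illustrative sketch (not the exact compiler semantics), the custom
# call emitted below computes, for each input position i,
#   activations[lhs_local_sample_ids[i], :] +=
#       lhs_gains[i] * embedding_table[lhs_local_embedding_ids[i], :]
# on top of the zero-initialized activation buffer built in the lowering.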
102 | def _tpu_local_sparse_dense_matmul_lowering(
103 |     ctx,
104 |     lhs_local_embedding_ids: np.ndarray,
105 |     lhs_local_sample_ids: np.ndarray,
106 |     lhs_gains: np.ndarray,
107 |     embedding_table: np.ndarray,
108 |     *,
109 |     device_batch_size: int,
110 | ) -> jnp.ndarray:
111 |   """Lowering for the local sparse dense matmul primitive."""
112 |   (out_aval,) = ctx.avals_out
113 | 
114 |   constant_op = hlo.constant(ir.DenseElementsAttr.get(np.float32(0.0)))
115 |   activation_init = hlo.broadcast(
116 |       constant_op,
117 |       mlir.dense_int_array([
118 |           device_batch_size,
119 |           ir.RankedTensorType(embedding_table.type).get_dim_size(1),
120 |       ]),
121 |   )
122 | 
123 |   backend_config = json.dumps({
124 |       "device_type": "DEVICE_TYPE_SPARSECORE",
125 |   })
126 | 
127 |   operands = [
128 |       embedding_table,
129 |       lhs_local_embedding_ids,
130 |       lhs_local_sample_ids,
131 |       lhs_gains,
132 |       activation_init,
133 |   ]
134 | 
135 |   return jax.ffi.ffi_lowering(
136 |       "SparseDenseMatmulLocalOp",
137 |       result_types=[mlir.aval_to_ir_type(out_aval)],
138 |       api_version=1,
139 |       backend_config=backend_config,
140 |       skip_ffi_layout_processing=True,
141 |   )(
142 |       ctx, *operands
143 |   )  # type: ignore
144 | 
145 | 
146 | mlir.register_lowering(
147 |     tpu_local_sparse_dense_matmul_primitive,
148 |     _tpu_local_sparse_dense_matmul_lowering,
149 | )
150 | 
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | workspace(name = "jax_tpu_embedding")
15 | 
16 | ###############################################################################
17 | ## XLA Initialization
18 | ###############################################################################
19 | # This is adapted from JAX's WORKSPACE file.
20 | 
21 | # The XLA commit is determined by third_party/xla/workspace.bzl.
22 | load("//third_party/xla:workspace.bzl", xla_repo = "repo")
23 | 
24 | xla_repo()
25 | 
26 | load("@xla//:workspace4.bzl", "xla_workspace4")
27 | xla_workspace4()
28 | 
29 | load("@xla//:workspace3.bzl", "xla_workspace3")
30 | xla_workspace3()
31 | 
32 | # Initialize hermetic C++.
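# The rules_ml_toolchain archive pinned below provides the hermetic clang
# toolchains registered via register_toolchains() further down; when bumping
# the pinned commit, strip_prefix and sha256 must be updated together.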
33 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 34 | 35 | http_archive( 36 | name = "rules_ml_toolchain", 37 | sha256 = "1ae689a7681da56dc109ab0bfd9c2588b056bd494db65452dabfbb7068cebc30", 38 | strip_prefix = "rules_ml_toolchain-5b12b030160b5e0a4a2188f808bc732d3afd45bd", 39 | urls = [ 40 | "https://github.com/google-ml-infra/rules_ml_toolchain/archive/5b12b030160b5e0a4a2188f808bc732d3afd45bd.tar.gz", 41 | ], 42 | ) 43 | 44 | load( 45 | "@rules_ml_toolchain//cc/deps:cc_toolchain_deps.bzl", 46 | "cc_toolchain_deps", 47 | ) 48 | 49 | cc_toolchain_deps() 50 | 51 | register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64") 52 | register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64") 53 | 54 | # Initialize hermetic Python 55 | load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") 56 | 57 | python_init_rules() 58 | 59 | load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") 60 | 61 | python_init_repositories( 62 | default_python_version = "system", 63 | requirements = { 64 | "3.10": "//third_party/py:requirements_lock_3_10.txt", 65 | "3.11": "//third_party/py:requirements_lock_3_11.txt", 66 | "3.12": "//third_party/py:requirements_lock_3_12.txt", 67 | "3.13": "//third_party/py:requirements_lock_3_13.txt", 68 | }, 69 | ) 70 | 71 | load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") 72 | 73 | python_init_toolchains() 74 | 75 | load("//third_party/bazel/python:python_init_pip.bzl", "python_init_pip") 76 | 77 | python_init_pip() 78 | 79 | load("@pypi//:requirements.bzl", "install_deps") 80 | 81 | install_deps() 82 | 83 | # Load all XLA dependencies. 84 | load("@xla//:workspace2.bzl", "xla_workspace2") 85 | 86 | xla_workspace2() 87 | 88 | load("@xla//:workspace1.bzl", "xla_workspace1") 89 | 90 | xla_workspace1() 91 | 92 | load("@xla//:workspace0.bzl", "xla_workspace0") 93 | 94 | xla_workspace0() 95 | 96 | # Even though we don't use CUDA, this is required since it is needed 97 | # by TSL, one of our dependencies. 
98 | load( 99 | "@rules_ml_toolchain//gpu/cuda:cuda_json_init_repository.bzl", 100 | "cuda_json_init_repository", 101 | ) 102 | cuda_json_init_repository() 103 | 104 | load( 105 | "@cuda_redist_json//:distributions.bzl", 106 | "CUDA_REDISTRIBUTIONS", 107 | "CUDNN_REDISTRIBUTIONS", 108 | ) 109 | load( 110 | "@rules_ml_toolchain//gpu/cuda:cuda_redist_init_repositories.bzl", 111 | "cuda_redist_init_repositories", 112 | "cudnn_redist_init_repository", 113 | ) 114 | 115 | cuda_redist_init_repositories( 116 | cuda_redistributions = CUDA_REDISTRIBUTIONS, 117 | ) 118 | 119 | cudnn_redist_init_repository( 120 | cudnn_redistributions = CUDNN_REDISTRIBUTIONS, 121 | ) 122 | 123 | load( 124 | "@rules_ml_toolchain//gpu/cuda:cuda_configure.bzl", 125 | "cuda_configure", 126 | ) 127 | 128 | cuda_configure(name = "local_config_cuda") 129 | 130 | ############################################################################### 131 | ## SparseCore-Specific Dependencies 132 | ############################################################################### 133 | 134 | HIGHWAY_VERSION = "1.2.0" 135 | HIGHWAY_SHA256 = "7e0be78b8318e8bdbf6fa545d2ecb4c90f947df03f7aadc42c1967f019e63343" 136 | HIGHWAY_ARCHIVE = "https://github.com/google/highway/archive/{version}.tar.gz".format(version = HIGHWAY_VERSION) 137 | http_archive( 138 | name = "highway", 139 | sha256 = HIGHWAY_SHA256, 140 | strip_prefix = "highway-{version}".format(version = HIGHWAY_VERSION), 141 | urls = [HIGHWAY_ARCHIVE], 142 | ) 143 | 144 | FUZZTEST_COMMIT = "0f82dad406f431ca5e8607626825be15423ba339" 145 | 146 | http_archive( 147 | name = "com_google_fuzztest", 148 | strip_prefix = "fuzztest-" + FUZZTEST_COMMIT, 149 | url = "https://github.com/google/fuzztest/archive/" + FUZZTEST_COMMIT + ".zip", 150 | ) 151 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/primitives/optimizers_computation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Defines common optimizers for embedding lookups.""" 15 | 16 | from jax.extend.mlir import ir 17 | from jax.extend.mlir.dialects import func as func_dialect 18 | from jax.extend.mlir.dialects import stablehlo as hlo 19 | from jax.interpreters import mlir 20 | 21 | 22 | def sgd( 23 | ctx: mlir.LoweringRuleContext, computation_name: str, dim_size: int 24 | ) -> None: 25 | """A Callable SGD lowering.""" 26 | optimizer_update = func_dialect.FuncOp( 27 | computation_name, 28 | ( 29 | [ 30 | ir.RankedTensorType.get( 31 | [1, dim_size], 32 | ir.F32Type.get(), 33 | ), 34 | ir.RankedTensorType.get( 35 | [1, dim_size], 36 | ir.F32Type.get(), 37 | ), 38 | ir.RankedTensorType.get( 39 | [1, dim_size], 40 | ir.F32Type.get(), 41 | ), 42 | ], 43 | [ 44 | ir.TupleType.get_tuple( 45 | [ 46 | ir.RankedTensorType.get( 47 | [1, dim_size], 48 | ir.F32Type.get(), 49 | ) 50 | ] 51 | ), 52 | ], 53 | ), 54 | ip=ctx.module_context.ip, 55 | visibility="private", 56 | ) 57 | entry_block = optimizer_update.add_entry_block() 58 | with ir.InsertionPoint(entry_block): 59 | # lr * grad 60 | gradient_update = hlo.multiply( 61 | entry_block.arguments[0], 62 | entry_block.arguments[2], 63 | ) 64 | # updated_embedding_table = embedding_table - lr * grad 65 | updated_embedding_table = hlo.subtract( 66 | entry_block.arguments[1], gradient_update 67 | ) 68 | updated_embedding_tables = hlo.tuple([updated_embedding_table]) 69 | func_dialect.ReturnOp([updated_embedding_tables]) 70 | 71 | 72 | def adagrad( 73 | ctx: mlir.LoweringRuleContext, computation_name: str, dim_size: int 74 | ) -> None: 75 | """A callable Adagrad lowering. 76 | 77 | When using this optimizer, the expected ordering of the embedding variables is 78 | 0. embedding_table 79 | 1. accumulator 80 | 81 | Args: 82 | ctx: The lowering rule context. 83 | computation_name: The name of the computation. 84 | dim_size: The dimension of the embedding table. 
85 | """ 86 | optimizer_update = func_dialect.FuncOp( 87 | computation_name, 88 | ( 89 | [ 90 | ir.RankedTensorType.get( 91 | [1, dim_size], 92 | ir.F32Type.get(), 93 | ), 94 | ir.RankedTensorType.get( 95 | [1, dim_size], 96 | ir.F32Type.get(), 97 | ), 98 | ir.RankedTensorType.get( 99 | [1, dim_size], 100 | ir.F32Type.get(), 101 | ), 102 | ir.RankedTensorType.get( 103 | [1, dim_size], 104 | ir.F32Type.get(), 105 | ), 106 | ], 107 | [ 108 | ir.TupleType.get_tuple([ 109 | ir.RankedTensorType.get( 110 | [1, dim_size], 111 | ir.F32Type.get(), 112 | ), 113 | ir.RankedTensorType.get( 114 | [1, dim_size], 115 | ir.F32Type.get(), 116 | ), 117 | ]), 118 | ], 119 | ), 120 | ip=ctx.module_context.ip, 121 | visibility="private", 122 | ) 123 | 124 | entry_block = optimizer_update.add_entry_block() 125 | with ir.InsertionPoint(entry_block): 126 | # new_accumulator = accumulator + grad * grad 127 | grad_squared = hlo.multiply( 128 | entry_block.arguments[0], 129 | entry_block.arguments[0], 130 | ) 131 | new_accumulator = hlo.add( 132 | entry_block.arguments[2], 133 | grad_squared, 134 | ) 135 | updated_embedding_table = hlo.subtract( 136 | entry_block.arguments[1], 137 | hlo.divide( 138 | hlo.multiply( 139 | entry_block.arguments[3], 140 | entry_block.arguments[0], 141 | ), 142 | hlo.sqrt(new_accumulator), 143 | ), 144 | ) 145 | updated_embedding_tables = hlo.tuple( 146 | [updated_embedding_table, new_accumulator] 147 | ) 148 | func_dialect.ReturnOp([updated_embedding_tables]) 149 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/grpc/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The JAX SC Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") 15 | load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") 16 | load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") 17 | load("@rules_cc//cc:cc_library.bzl", "cc_library") 18 | load("@rules_cc//cc:cc_test.bzl", "cc_test") 19 | load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "CORE_USERS") 20 | 21 | package( 22 | default_applicable_licenses = ["//:license"], 23 | default_visibility = CORE_USERS, 24 | ) 25 | 26 | proto_library( 27 | name = "all_reduce_proto", 28 | srcs = ["all_reduce.proto"], 29 | # has_services = True, 30 | ) 31 | 32 | cc_proto_library( 33 | name = "all_reduce_cc_proto", 34 | deps = [":all_reduce_proto"], 35 | ) 36 | 37 | cc_grpc_library( 38 | name = "all_reduce_cc_grpc_proto", 39 | srcs = [":all_reduce_proto"], 40 | grpc_only = True, 41 | deps = [":all_reduce_cc_proto"], 42 | ) 43 | 44 | cc_library( 45 | name = "grpc_all_reduce_interface", 46 | srcs = ["all_reduce_interface.cc"], 47 | hdrs = ["all_reduce_interface.h"], 48 | deps = [ 49 | ":all_reduce_cc_grpc_proto", 50 | ":all_reduce_service_impl", 51 | ":grpc_credentials", 52 | "//jax_tpu_embedding/sparsecore/lib/core:all_reduce_interface", 53 | "@com_github_grpc_grpc//:grpc++", 54 | "@com_google_absl//absl/base:core_headers", 55 | "@com_google_absl//absl/container:flat_hash_map", 56 | "@com_google_absl//absl/log", 57 | "@com_google_absl//absl/log:check", 58 | "@com_google_absl//absl/status", 59 | "@com_google_absl//absl/status:statusor", 60 | "@com_google_absl//absl/strings", 61 | "@com_google_absl//absl/synchronization", 62 | "@com_google_absl//absl/time", 63 | "@tsl//tsl/platform:errors", 64 | "@tsl//tsl/platform:statusor", 65 | "@tsl//tsl/profiler/lib:traceme", 66 | ], 67 | ) 68 | 69 | cc_library( 70 | name = "all_reduce_service_impl", 71 | srcs = ["all_reduce_service_impl.cc"], 72 | hdrs = ["all_reduce_service_impl.h"], 73 | deps = [ 74 | ":all_reduce_cc_grpc_proto", 75 | ":all_reduce_cc_proto", 76 | "@com_github_grpc_grpc//:grpc++", 77 | "@com_google_absl//absl/base:core_headers", 78 | "@com_google_absl//absl/container:flat_hash_map", 79 | "@com_google_absl//absl/log", 80 | "@com_google_absl//absl/log:check", 81 | "@com_google_absl//absl/status", 82 | "@com_google_absl//absl/status:statusor", 83 | "@com_google_absl//absl/synchronization", 84 | "@com_google_absl//absl/time", 85 | ], 86 | ) 87 | 88 | cc_library( 89 | name = "grpc_credentials", 90 | hdrs = ["grpc_credentials.h"], 91 | deps = [ 92 | "@com_github_grpc_grpc//:grpc++", 93 | ], 94 | ) 95 | 96 | cc_library( 97 | name = "minibatching_node", 98 | hdrs = ["minibatching_node.h"], 99 | deps = [ 100 | ":all_reduce_service_impl", 101 | ":grpc_all_reduce_interface", 102 | ":grpc_credentials", 103 | "//jax_tpu_embedding/sparsecore/lib/core:all_reduce_interface", 104 | "//jax_tpu_embedding/sparsecore/lib/core/grpc/oss:grpc_credentials", # buildcleaner: keep 105 | "@com_github_grpc_grpc//:grpc++", 106 | "@com_google_absl//absl/base:core_headers", 107 | "@com_google_absl//absl/strings", 108 | ], 109 | ) 110 | 111 | cc_test( 112 | name = "all_reduce_test", 113 | srcs = ["all_reduce_test.cc"], 114 | deps = [ 115 | ":all_reduce_cc_proto", 116 | ":all_reduce_service_impl", 117 | ":grpc_all_reduce_interface", 118 | ":grpc_credentials", 119 | ":minibatching_node", 120 | "//jax_tpu_embedding/sparsecore/lib/core:minibatching_test_utils", 121 | "//jax_tpu_embedding/sparsecore/lib/core/grpc/oss:grpc_credentials", # buildcleaner: keep 122 | 
"@com_github_grpc_grpc//:grpc++", 123 | "@com_google_absl//absl/algorithm:container", 124 | "@com_google_absl//absl/log:check", 125 | "@com_google_absl//absl/status", 126 | "@com_google_absl//absl/status:status_matchers", 127 | "@com_google_absl//absl/status:statusor", 128 | "@com_google_absl//absl/strings", 129 | "@com_google_absl//absl/synchronization", 130 | "@com_google_fuzztest//fuzztest", 131 | "@com_google_fuzztest//fuzztest:googletest_fixture_adapter", 132 | "@com_google_googletest//:gtest_main", 133 | "@tsl//tsl/platform:env", 134 | "@tsl//tsl/platform:test", 135 | ], 136 | ) 137 | -------------------------------------------------------------------------------- /jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The JAX SC Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_SPLITS_IMPL_H_ 15 | #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_SPLITS_IMPL_H_ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "absl/functional/function_ref.h" // from @com_google_absl 23 | #include "absl/log/check.h" // from @com_google_absl 24 | #include "absl/numeric/bits.h" // from @com_google_absl 25 | #include "absl/types/span.h" // from @com_google_absl 26 | #include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h" 27 | 28 | namespace jax_sc_embedding { 29 | namespace internal { 30 | 31 | // Computes a minibatching split indicator for a given set of unique IDs per 32 | // bucket. 33 | // 34 | // Example: 35 | // Suppose we have 8 buckets and `max_unique_ids_per_partition` is 7. 36 | // unique_ids_per_bucket = {2, 5, 1, 6, 3, 4, 2, 0} 37 | // 38 | // The binary tree computation proceeds as follows: 39 | // - Level 1: 40 | // - {2, 5} -> 7 <= 7, no split, unique_ids_per_bucket[0] = 7 41 | // - {1, 6} -> 7 <= 7, no split, unique_ids_per_bucket[2] = 7 42 | // - {3, 4} -> 7 <= 7, no split, unique_ids_per_bucket[4] = 7 43 | // - {2, 0} -> 2 <= 7, no split, unique_ids_per_bucket[6] = 2 44 | // - Level 2: 45 | // - {7, 7} -> 14 > 7, split, split_index = 4 46 | // - {7, 2} -> 9 > 7, split, split_index = 5 47 | // - Level 3: 48 | // - {14, 9} -> 23 > 7, split, split_index = 6 49 | // 50 | // The resulting `split` bitset is 0b1110000. 51 | // 52 | // WARNING: This function modifies the input `unique_ids_per_bucket` by 53 | // combining counts of sibling buckets for efficiency. 54 | // 55 | // Example for subtree_size: 56 | // - subtree_size = 2: {0,1} {2,3} {4,5}, ... 57 | // - subtree_size = 4: {0,2} {4,6} {8,10}, ... 58 | // - ... 59 | // - subtree_size = N: {0, N/2} 60 | // If two buckets are merged, the count is combined into the left sibling and 61 | // propagated up the tree. 
62 | template <size_t N>
63 | std::bitset<N - 1> ComputeMinibatchingSplit(
64 |     absl::Span<int32_t> unique_ids_per_bucket,
65 |     int max_unique_ids_per_partition) {
66 |   static_assert(absl::has_single_bit(N));
67 |   DCHECK_EQ(unique_ids_per_bucket.size(), N);
68 |   std::bitset<N - 1> split = 0;
69 |   int split_index = 0;
70 |   for (int subtree_size = 2; subtree_size <= N; subtree_size *= 2) {
71 |     for (int i = 0; i < N; i += subtree_size, ++split_index) {
72 |       const int val_left = unique_ids_per_bucket[i];
73 |       const int val_right = unique_ids_per_bucket[i + subtree_size / 2];
74 |       if (val_left + val_right > max_unique_ids_per_partition) {
75 |         split.set(split_index);
76 |         unique_ids_per_bucket[i] = std::max(val_left, val_right);
77 |       } else {
78 |         unique_ids_per_bucket[i] += val_right;
79 |       }
80 |     }
81 |   }
82 |   return split;
83 | }
84 | 
85 | using MergeFn = absl::FunctionRef<void(int, int)>;
86 | struct NoOpMerge {
87 |   void operator()(int, int) const {}
88 | };
89 | 
90 | // Merges buckets based on a binary tree split indicator.
91 | // The `split` bitset indicates which nodes of the binary tree should be split.
92 | // If a node is not split, its left and right children are merged using the
93 | // provided `merge_fun`.
94 | //
95 | // Example:
96 | //   Input split: 0b1100011 (split index {0,1,5,6})
97 | //   This represents splits between:
98 | //     - {0,1}
99 | //     - {2,3}
100 | //     - {5,6}
101 | //     - {3,4}
102 | //   This function will call `merge_fun` for the following pairs of buckets:
103 | //     - index = 2: {4,5}
104 | //     - index = 3: {6,7}
105 | //     - index = 4: {0,2}
106 | //
107 | //   pos          0 1  2 3  4 5  6 7
108 | //   split index   0    1    2    3
109 | //   split index     4         5
110 | //   split index          6
111 | //
112 | // The `merge_fun` is called with the indices of the buckets to be merged.
113 | template <size_t N>
114 | void MergeBuckets(std::bitset<N - 1> split, MergeFn merge_fun = NoOpMerge()) {
115 |   static_assert(absl::has_single_bit(N));
116 |   int split_index = 0;
117 |   for (int subtree_size = 2; subtree_size <= N; subtree_size *= 2) {
118 |     for (int i = 0; i < N; i += subtree_size, ++split_index) {
119 |       const int right_index = i + subtree_size / 2;
120 |       const int left_index = i;
121 |       // Merge/Split left subtree with right subtree.
122 |       // Implementations decide how to do the merging; for input preprocessing,
123 |       // we merge the last unmerged node in the left subtree with the first
124 |       // unmerged node in the right subtree.
125 |       if (!split.test(split_index)) {
126 |         merge_fun(left_index, right_index);
127 |       }
128 |     }
129 |   }
130 | }
131 | 
132 | }  // namespace internal
133 | }  // namespace jax_sc_embedding
134 | 
135 | #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_SPLITS_IMPL_H_
136 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/core/primitives/tests/local_sparse_dense_matmul_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from absl.testing import absltest 15 | import jax 16 | import jax.numpy as jnp 17 | from jax_tpu_embedding.sparsecore.lib.core.primitives import local_sparse_dense_matmul 18 | from jax_tpu_embedding.sparsecore.utils import utils 19 | import numpy as np 20 | 21 | jax.config.update("jax_enable_x64", True) 22 | 23 | 24 | class SparseDenseMatmulCsrTest(absltest.TestCase): 25 | 26 | def setUp(self): 27 | super().setUp() 28 | self.num_chips = 1 29 | self.batch_size = 16 30 | self.vocab_size = 32 31 | self.emb_size = 8 32 | self.num_sc_per_device = utils.num_sparsecores_per_device(jax.devices()[0]) 33 | self.embedding_ids = np.asarray( 34 | [5, 3, 9, 1, 6, 12, 0, 4, 15, 13, 11, 7, 8, 14, 2, 10], 35 | dtype=np.int64, 36 | ) 37 | self.sample_ids = np.arange(self.batch_size, dtype=np.int32) 38 | self.gains = np.ones_like(self.embedding_ids, dtype=np.float32) 39 | # Define the embedding table. 40 | self.emb_table = ( 41 | np.array( 42 | [[i for _ in range(self.emb_size)] for i in range(self.vocab_size)] 43 | ) 44 | .reshape(self.vocab_size, self.emb_size) 45 | .astype(np.float32) 46 | ) 47 | 48 | self.tpu_local_sparse_dense_matmul = jax.named_call( 49 | local_sparse_dense_matmul.tpu_local_sparse_dense_matmul_primitive.bind, 50 | name="tpu_local_sparse_dense_matmul", 51 | ) 52 | 53 | def test_sc_emb_forward_pass_invalid_input_dtypes(self): 54 | with self.subTest("invalid_local_embedding_ids_type"): 55 | with self.assertRaises(ValueError): 56 | self.tpu_local_sparse_dense_matmul( 57 | self.embedding_ids.astype(jnp.float64), 58 | self.sample_ids, 59 | self.gains, 60 | self.emb_table, 61 | device_batch_size=self.batch_size, 62 | ) 63 | 64 | with self.subTest("invalid_local_sample_ids_type"): 65 | with self.assertRaises(ValueError): 66 | self.tpu_local_sparse_dense_matmul( 67 | self.embedding_ids, 68 | self.sample_ids.astype(jnp.float32), 69 | self.gains, 70 | self.emb_table, 71 | device_batch_size=self.batch_size, 72 | ) 73 | 74 | with self.subTest("invalid_gains_type"): 75 | with self.assertRaises(ValueError): 76 | self.tpu_local_sparse_dense_matmul( 77 | self.embedding_ids, 78 | self.sample_ids, 79 | self.gains.astype(jnp.int32), 80 | self.emb_table, 81 | device_batch_size=self.batch_size, 82 | ) 83 | 84 | with self.subTest("invalid_emb_table_type"): 85 | with self.assertRaises(ValueError): 86 | self.tpu_local_sparse_dense_matmul( 87 | self.embedding_ids, 88 | self.sample_ids, 89 | self.gains, 90 | self.emb_table.astype(jnp.int32), 91 | device_batch_size=self.batch_size, 92 | ) 93 | 94 | def test_sc_emb_forward_pass_invalid_input_shapes(self): 95 | with self.subTest("invalid_sample_id_shape"): 96 | ids = self.embedding_ids.reshape(4, 4) 97 | with self.assertRaises(ValueError): 98 | self.tpu_local_sparse_dense_matmul( 99 | ids, 100 | self.sample_ids, 101 | self.gains, 102 | self.emb_table, 103 | device_batch_size=self.batch_size, 104 | ) 105 | 106 | def test_sc_emb_forward_pass(self): 107 | # Do the embedding lookup. 108 | emb_activations = self.tpu_local_sparse_dense_matmul( 109 | jnp.asarray(self.embedding_ids, dtype=jnp.int32), 110 | jnp.asarray(self.sample_ids, dtype=jnp.int32), 111 | jnp.asarray(self.gains, dtype=jnp.float32), 112 | jnp.asarray(self.emb_table, dtype=jnp.float32), 113 | device_batch_size=self.batch_size // self.num_chips, 114 | ) 115 | 116 | # Check the embedding activations. 
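    # With unit gains and identity sample_ids, row i of the activations should
    # equal emb_table[embedding_ids[i]], i.e. a row filled with the value
    # embedding_ids[i].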
117 |     expected_emb_activations = np.array(
118 |         [
119 |             [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
120 |             [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
121 |             [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0],
122 |             [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
123 |             [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0],
124 |             [12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0],
125 |             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
126 |             [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0],
127 |             [15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
128 |             [13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0],
129 |             [11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0],
130 |             [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0],
131 |             [8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0],
132 |             [14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0],
133 |             [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
134 |             [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
135 |         ],
136 |         dtype=np.float32,
137 |     )
138 | 
139 |     np.testing.assert_equal(emb_activations, expected_emb_activations)
140 | 
141 | 
142 | if __name__ == "__main__":
143 |   absltest.main()
144 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/auto_pipelining/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Preprocessor for auto-pipelining."""
15 | 
16 | import jax.extend as jex
17 | 
18 | from jax_tpu_embedding.sparsecore.lib.auto_pipelining import utils
19 | 
20 | 
21 | def _has_primitive(eqn: jex.core.JaxprEqn, primitive_name_prefix: str) -> bool:
22 |   """Checks if a JaxprEqn contains a primitive with the given prefix.
23 | 
24 |   This function recursively checks the equation and any nested Jaxprs (e.g., in
25 |   conditionals or loops) for the presence of the primitive.
26 | 
27 |   Args:
28 |     eqn: The JaxprEqn to check.
29 |     primitive_name_prefix: The prefix of the primitive name to search for.
30 | 
31 |   Returns:
32 |     True if the primitive is found, False otherwise.
33 |   """
34 |   if eqn.primitive.name.startswith(primitive_name_prefix):
35 |     return True
36 |   for param in eqn.params.values():
37 |     if isinstance(param, jex.core.ClosedJaxpr) or isinstance(
38 |         param, jex.core.Jaxpr
39 |     ):
40 |       if any(_has_primitive(eqn, primitive_name_prefix) for eqn in param.eqns):
41 |         return True
42 |   return False
43 | 
44 | 
45 | def _has_embedding_lookup(eqn: jex.core.JaxprEqn) -> bool:
46 |   """Checks if a JaxprEqn contains an embedding lookup operation."""
47 |   return _has_primitive(eqn, utils.EMBEDDING_LOOKUP_PRIMITIVE_PREFIX)
48 | 
49 | 
50 | def _has_embedding_update(eqn: jex.core.JaxprEqn) -> bool:
51 |   """Checks if a JaxprEqn contains an embedding update operation."""
52 |   return _has_primitive(eqn, utils.EMBEDDING_UPDATE_PRIMITIVE_PREFIX)
53 | 
54 | 
55 | def _inline_custom_vjp(jaxpr: jex.core.Jaxpr) -> jex.core.Jaxpr:
56 |   """Inlines embedding lookup inside custom_vjp_call_jaxpr."""
57 |   eqns = []
58 |   for eqn in jaxpr.eqns:
59 |     if eqn.primitive.name == 'custom_vjp_call' and _has_embedding_lookup(
60 |         eqn
61 |     ):
62 |       eqns.extend(
63 |           utils.inline_jaxpr(
64 |               eqn.params['call_jaxpr'].jaxpr, eqn.invars, eqn.outvars
65 |           )
66 |       )
67 |     else:
68 |       eqns.append(eqn)
69 |   return jaxpr.replace(eqns=eqns)
70 | 
71 | 
72 | def _validate_embedding_lookup(eqn: jex.core.JaxprEqn) -> None:
73 |   """Validates whether the embedding lookups can be transformed."""
74 |   # shard_map should be on the top level so that we can combine them.
75 |   assert (
76 |       eqn.primitive.name == utils.SHARD_MAP_PRIMITIVE_NAME
77 |   ), 'Embedding lookup should be wrapped directly by shard_map'
78 |   # The lookup primitive should be the first equation in the shard_map, for an
79 |   # easy check in the later transformations.
80 |   jaxpr = eqn.params['jaxpr']
81 |   # This will assert if there's no lookup operation in the shard_map.
82 |   lookup_eqn = utils.get_embedding_lookup_eqn(jaxpr.eqns)
83 | 
84 |   # Embedding table should be the last input of the lookup shard_map.
85 |   # Slot variables are not used for embedding lookup.
86 |   assert (
87 |       lookup_eqn.invars[-1] == jaxpr.invars[utils.EMBEDDING_LOOKUP_DATA_LEN]
88 |   ), 'Embedding table should be the last input of the lookup shard_map'
89 | 
90 | 
91 | def _validate_embedding_update(eqn: jex.core.JaxprEqn) -> None:
92 |   """Validates whether the embedding updates can be transformed."""
93 |   # shard_map should be on the top level so that we can combine them.
94 |   assert (
95 |       eqn.primitive.name == utils.SHARD_MAP_PRIMITIVE_NAME
96 |   ), 'Embedding update should be wrapped directly by shard_map'
97 |   # The update primitive should be the last equation in the shard_map, for an
98 |   # easy check in the later transformations.
99 |   jaxpr = eqn.params['jaxpr']
100 |   update_eqn = jaxpr.eqns[-1]
101 |   assert update_eqn.primitive.name.startswith(
102 |       utils.EMBEDDING_UPDATE_PRIMITIVE_PREFIX
103 |   ), (
104 |       'The last equation in the shard_map is not an embedding update. '
105 |       f'Got {update_eqn.primitive.name}'
106 |   )
107 | 
108 |   # The embedding table should be the last input of the update shard_map.
109 |   # Used when passing updates from dense to SC backward.
110 |   embed_tables = jaxpr.invars[utils.EMBEDDING_UPDATE_DATA_LEN :]
111 |   assert (
112 |       update_eqn.invars[utils.EMBEDDING_UPDATE_DATA_LEN - 1 :][
113 |           : len(embed_tables)
114 |       ]
115 |       == embed_tables
116 |   ), 'Embedding table should be the last input of the update shard_map'
117 | 
118 |   # Embedding table should be the only output of the update shard_map.
119 |   # This is used when we combine the lookup and update shard_maps.
120 |   assert (
121 |       update_eqn.outvars == jaxpr.outvars
122 |   ), 'Embedding table should be the only output of the update shard_map'
123 | 
124 | 
125 | def validate_jaxpr(jaxpr: jex.core.Jaxpr) -> None:
126 |   """Validates the structure of the Jaxpr for auto-pipelining."""
127 |   for eqn in jaxpr.eqns:
128 |     has_lookup = _has_embedding_lookup(eqn)
129 |     has_update = _has_embedding_update(eqn)
130 |     assert not (
131 |         has_lookup and has_update
132 |     ), 'Embedding lookup and update should not be in the same equation'
133 |     if has_lookup:
134 |       _validate_embedding_lookup(eqn)
135 |     if has_update:
136 |       _validate_embedding_update(eqn)
137 | 
138 | 
139 | def preprocess(jaxpr: jex.core.Jaxpr) -> jex.core.Jaxpr:
140 |   """Preprocesses the Jaxpr for auto-pipelining."""
141 |   jaxpr = _inline_custom_vjp(jaxpr)
142 |   validate_jaxpr(jaxpr)
143 |   return jaxpr
144 | 
--------------------------------------------------------------------------------
/jax_tpu_embedding/sparsecore/lib/fdo/file_fdo_client_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The JAX SC Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Unit tests for the file-based FDO client."""
15 | 
16 | import os
17 | 
18 | from absl.testing import absltest
19 | from jax_tpu_embedding.sparsecore.lib.fdo import file_fdo_client
20 | from jax_tpu_embedding.sparsecore.lib.nn import embedding
21 | import numpy as np
22 | 
23 | 
24 | class NpzFdoClientTest(absltest.TestCase):
25 | 
26 |   def setUp(self):
27 |     super().setUp()
28 |     self.base_dir = self.create_tempdir(
29 |         cleanup=absltest.TempFileCleanup.OFF
30 |     ).full_path
31 | 
32 |   def _assert_stats_equal(self, actual, expected):
33 |     self.assertLen(actual, len(expected))
34 |     for key in expected:
35 |       self.assertIn(key, actual)
36 |       np.testing.assert_array_equal(expected[key], actual[key])
37 | 
38 |   def test_record_and_publish_load(self):
39 |     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
40 |     stats = embedding.SparseDenseMatmulInputStats(
41 |         max_ids_per_partition={"tab_one": np.array([10, 20, 30, 40])},
42 |         max_unique_ids_per_partition={"tab_one": np.array([1, 2, 3, 4])},
43 |         required_buffer_size_per_sc={},
44 |     )
45 |     fdo_client.record(stats)
46 |     fdo_client.publish()
47 |     loaded_stats = fdo_client.load()
48 |     self._assert_stats_equal(
49 |         loaded_stats.max_ids_per_partition, stats.max_ids_per_partition
50 |     )
51 |     self._assert_stats_equal(
52 |         loaded_stats.max_unique_ids_per_partition,
53 |         stats.max_unique_ids_per_partition,
54 |     )
55 |     self._assert_stats_equal(
56 |         loaded_stats.required_buffer_size_per_sc,
57 |         stats.required_buffer_size_per_sc,
58 |     )
59 | 
60 |   def test_multiple_record(self):
61 |     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
62 |     stats = embedding.SparseDenseMatmulInputStats(
63 |         max_ids_per_partition={"tab_one": np.array([10, 20, 30, 40])},
64 |         max_unique_ids_per_partition={"tab_one": np.array([1, 2, 3, 4])},
65 |         required_buffer_size_per_sc={"tab_one": np.array([256])},
66 |     )
67 |     fdo_client.record(stats)
68 |     fdo_client.record(stats)  # Recorded twice on purpose to test accumulation.
69 |     fdo_client.publish()
70 |     stats = fdo_client.load()
71 | 
72 |     self._assert_stats_equal(
73 |         stats.max_ids_per_partition,
74 |         {"tab_one": np.array([[10, 20, 30, 40], [10, 20, 30, 40]])},
75 |     )
76 |     self._assert_stats_equal(
77 |         stats.max_unique_ids_per_partition,
78 |         {"tab_one": np.array([[1, 2, 3, 4], [1, 2, 3, 4]])},
79 |     )
80 |     self._assert_stats_equal(
81 |         stats.required_buffer_size_per_sc, {"tab_one": np.array([[256], [256]])}
82 |     )
83 | 
84 |   def test_load_multiple_files(self):
85 |     base_dir = self.create_tempdir().full_path
86 |     np.savez(
87 |         os.path.join(base_dir, "fdo_stats_0_10.npz"),
88 |         **{
89 |             "t_one_max_ids": np.array([10, 20, 30, 40]),
90 |             "t_two_max_ids": np.array([50, 60, 70, 80]),
91 |             "t_one_max_unique_ids": np.array([1, 2, 3, 4]),
92 |             "t_two_max_unique_ids": np.array([5, 6, 7, 8]),
93 |             "t_one_required_buffer_size": np.array([64]),
94 |             "t_two_required_buffer_size": np.array([128]),
95 |         },
96 |     )
97 |     np.savez(
98 |         os.path.join(base_dir, "fdo_stats_1_11.npz"),
99 |         **{
100 |             "t_one_max_ids": np.array([20, 10, 40, 30]),
101 |             "t_two_max_ids": np.array([60, 60, 80, 70]),
102 |             "t_one_max_unique_ids": np.array([2, 1, 4, 3]),
103 |             "t_two_max_unique_ids": np.array([6, 5, 8, 7]),
104 |             "t_one_required_buffer_size": np.array([128]),
105 |             "t_two_required_buffer_size": np.array([256]),
106 |         },
107 |     )
108 | 
109 |     fdo_client = file_fdo_client.NPZFileFDOClient(base_dir)
110 |     loaded_stats = fdo_client.load()
111 |     self._assert_stats_equal(
112 |         loaded_stats.max_ids_per_partition,
113 |         {
114 |             "t_one": np.array([20, 20, 40, 40]),
115 |             "t_two": np.array([60, 60, 80, 80]),
116 |         },
117 |     )
118 |     self._assert_stats_equal(
119 |         loaded_stats.max_unique_ids_per_partition,
120 |         {
121 |             "t_one": np.array([2, 2, 4, 4]),
122 |             "t_two": np.array([6, 6, 8, 8]),
123 |         },
124 |     )
125 |     self._assert_stats_equal(
126 |         loaded_stats.required_buffer_size_per_sc,
127 |         {"t_one": np.array([128]), "t_two": np.array([256])},
128 |     )
129 | 
130 |   def test_files_not_found(self):
131 |     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
132 |     with self.assertRaises(FileNotFoundError):
133 |       _ = fdo_client.load()
134 | 
135 |   def test_latest_files_by_process(self):
136 |     files = [
137 |         "temp/fdo_dumps/fdo_stats_0_10.npz",
138 |         "temp/fdo_dumps/fdo_stats_0_20.npz",
139 |         "temp/fdo_dumps/fdo_stats_0_30.npz",
140 |         "temp/fdo_dumps/fdo_stats_1_09.npz",
141 |         "temp/fdo_dumps/fdo_stats_0_40.npz",
142 |         "temp/fdo_dumps/fdo_stats_2_10.npz",
143 |     ]
144 |     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
145 |     latest_files = fdo_client._get_latest_files_by_process(files)
146 |     self.assertEqual(
147 |         latest_files,
148 |         [
149 |             "temp/fdo_dumps/fdo_stats_2_10.npz",
150 |             "temp/fdo_dumps/fdo_stats_1_09.npz",
151 |             "temp/fdo_dumps/fdo_stats_0_40.npz",
152 |         ],
153 |     )
154 | 
155 | 
156 | if __name__ == "__main__":
157 |   absltest.main()
158 | 
--------------------------------------------------------------------------------