├── AUTHORS
├── clu
│ ├── run_pytest.google.sh
│ ├── __init__.py
│ ├── internal
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── metric_writers
│ │ ├── tf
│ │ │ ├── __init__.py
│ │ │ ├── summary_writer.py
│ │ │ └── summary_writer_test.py
│ │ ├── summary_writer.py
│ │ ├── multi_writer_test.py
│ │ ├── __init__.py
│ │ ├── multi_writer.py
│ │ ├── torch_tensorboard_writer.py
│ │ ├── torch_tensorboard_writer_test.py
│ │ ├── logging_writer_test.py
│ │ ├── async_writer_test.py
│ │ ├── utils.py
│ │ ├── logging_writer.py
│ │ ├── async_writer.py
│ │ ├── utils_test.py
│ │ └── interface.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── dataset_iterator_test.py
│ │ └── dataset_iterator.py
│ ├── platform
│ │ ├── __init__.py
│ │ ├── local.py
│ │ └── interface.py
│ ├── profiler.py
│ ├── values.py
│ ├── asynclib_test.py
│ ├── asynclib.py
│ ├── parameter_overview_test.py
│ ├── preprocess_spec_test.py
│ ├── periodic_actions_test.py
│ ├── parameter_overview.py
│ ├── deterministic_data_test.py
│ └── preprocess_spec.py
├── README.md
├── .github
│ └── workflows
│   ├── python-publish.yml
│   └── build.yml
├── CONTRIBUTING.md
├── setup.py
├── CHANGELOG.md
└── LICENSE
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Common Loop Utils significant contributors.
2 | #
3 | # This does not necessarily list everyone who has contributed code,
4 | # especially since many employees of one corporation may be contributing.
5 | # To see the full list of contributors, see the revision history in
6 | # source control.
7 | Google LLC
8 |
--------------------------------------------------------------------------------
/clu/run_pytest.google.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e -x
4 |
5 | CLU_DST="${CLU_DST:-/tmp/clu}"
6 | CLU_ENV="${CLU_ENV:-/tmp/clu_env}"
7 |
8 | copybara third_party/py/clu/copy.bara.sky local .. \
9 | --folder-dir="${CLU_DST}" --ignore-noop
10 |
11 | # Note: we're reusing the environment if it already exists.
12 | mkdir -p "${CLU_ENV}"
13 | cd "${CLU_ENV}"
14 | python3 -m virtualenv .
15 | . bin/activate
16 |
17 | cd "${CLU_DST}"
18 | pip install . .[test]
19 |
20 | pytest
21 |
--------------------------------------------------------------------------------
/clu/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/clu/internal/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/clu/metric_writers/tf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Package __init__ file."""
16 |
--------------------------------------------------------------------------------
/clu/metric_writers/summary_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter for writing to TF summary files."""
16 | # pylint: disable=unused-import
17 |
18 | from .tf.summary_writer import SummaryWriter
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLU - Common Loop Utils
2 |
3 | This repository contains common functionality for writing ML training loops. The
4 | goal is to make training loops short and readable (by moving common tasks to
5 | small libraries) without removing the flexibility required for research.
6 |
7 | To get started, check out this Colab:
8 |
9 | https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb
10 |
11 | If you're looking for usage examples, see:
12 |
13 | https://github.com/google/flax/tree/main/examples
14 |
15 | You can also find answers to common questions about CLU on the Flax GitHub
16 | discussions page:
17 |
18 | https://github.com/google/flax/discussions
19 |
20 | Note: At this point we are not accepting contributions. Please fork the
21 | repository if you want to extend the libraries for your use case.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: Set up Python
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: __token__
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/clu/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """DatasetIterator is an interface for input pipelines."""
16 | # pylint: disable=g-multiple-import
17 | # pylint: disable=unused-import
18 |
19 | from clu.data.dataset_iterator import (
20 | Array,
21 | ArraySpec,
22 | DatasetIterator,
23 | Element,
24 | ElementSpec,
25 | TfDatasetIterator,
26 | PeekableDatasetIterator,
27 | PyTree,
28 | )
29 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | At this time we are focused on supporting research done by Google Research and
4 | are not accepting patches.
5 |
6 | You are however free to start a fork of the project for your purposes as
7 | permitted by the license.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement (CLA). You (or your employer) retain the copyright to your
13 | contribution; this simply gives us permission to use and redistribute your
14 | contributions as part of the project. Head over to
15 | https://cla.developers.google.com/ to see your current agreements on file or
16 | to sign a new one.
17 |
18 | You generally only need to submit a CLA once, so if you've already submitted one
19 | (even if it was for a different project), you probably don't need to do it
20 | again.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows
25 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint.
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Build
5 |
6 | on:
7 | push:
8 | branches:
9 | - main
10 | - 'test_*'
11 | pull_request:
12 | branches:
13 | - main
14 |
15 | jobs:
16 | build:
17 | runs-on: ubuntu-latest
18 | strategy:
19 | matrix:
20 | python-version: ['3.10', '3.11']
21 | steps:
22 | - name: Cancel previous
23 | uses: styfle/cancel-workflow-action@0.8.0
24 | with:
25 | access_token: ${{ github.token }}
26 | - uses: actions/checkout@v4
27 | - name: Set up Python ${{ matrix.python-version }}
28 | uses: actions/setup-python@v5
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 | - name: Install dependencies
32 | run: |
33 | pip install .
34 | pip install .[test]
35 | - name: Test with pytest and generate coverage report
36 | run: |
37 | pytest .
38 | - name: Upload coverage to Codecov
39 | uses: codecov/codecov-action@v1
40 | with:
41 | file: ./coverage.xml
42 |
--------------------------------------------------------------------------------
/clu/platform/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Methods for interacting with the experiment platform.
16 |
17 | Use cases include informing the platform of the experiment status and providing
18 | a platform independent interface for interactions.
19 | """
20 |
21 | import threading
22 |
23 |
24 | from clu.platform.interface import ArtifactType
25 | from clu.platform.interface import WorkUnit
26 | from clu.platform.local import LocalWorkUnit
27 |
28 | # TODO(b/200953513): Migrate away from logging imports (on module level)
29 | # to logging the actual usage. See b/200953513.
30 |
31 | _work_unit = None
32 | _work_unit_lock = threading.Lock()
33 |
34 |
35 | def work_unit() -> WorkUnit:
36 | """Gets the global work unit for this experiment trial."""
37 | global _work_unit
38 | if _work_unit is None:
39 | with _work_unit_lock:
40 | _work_unit = LocalWorkUnit()
41 | return _work_unit
42 |
--------------------------------------------------------------------------------
/clu/platform/local.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implementation for platform functionality when running locally."""
16 |
17 | from typing import Any
18 |
19 | from absl import logging
20 | from clu.platform import interface
21 |
22 | WorkUnit = interface.WorkUnit
23 | ArtifactType = interface.ArtifactType
24 |
25 |
26 | class LocalWorkUnit(WorkUnit):
27 | """Dummy work unit for running locally."""
28 |
29 | @property
30 | def experiment_id(self):
31 | """ID of the experiment of the work unit."""
32 | return -1
33 |
34 | @property
35 | def id(self):
36 | """Unique identifier for the work unit."""
37 | return -1
38 |
39 | def set_notes(self, msg: str):
40 | """Sets the notes for this work unit."""
41 | logging.info("Setting work unit notes: %s", msg)
42 |
43 | def set_task_status(self, msg: str):
44 | """Sets the status string for this task."""
45 | logging.info("Setting task status: %s", msg)
46 |
47 | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
48 | description: str):
49 | """Creates an artifact entry for the work unit."""
50 | logging.info("Created artifact %s of type %s and value %s.", description,
51 | artifact_type, artifact)
52 |
--------------------------------------------------------------------------------
/clu/profiler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Methods for running triggering a profiler for accelerators.
16 |
17 | Where results are stored depends on the platform (e.g. TensorBoard).
18 | """
19 | from collections.abc import Callable, Sequence
20 | import threading
21 | from typing import Optional, Protocol
22 |
23 | from absl import logging
24 |
25 | import jax
26 |
27 |
28 |
29 | def start(logdir: str, options=None):
30 | """Starts profiling."""
31 | if options is not None:
32 | raise NotImplementedError(
33 | "'options' not supported by clu.profiler.start(). Please file an issue "
34 | "at https://github.com/google/jax/issues requesting profiler option "
35 | "support if you need this feature.")
36 | if logdir is None:
37 | raise ValueError("Must specify logdir where profile should be written!")
38 | jax.profiler.start_trace(logdir)
39 |
40 |
41 | def stop() -> Optional[str]:
42 | """Stops profiling."""
43 | jax.profiler.stop_trace()
44 |
45 |
46 | CollectCallback = Callable[[Optional[str]], None]
47 |
48 |
49 | def collect(logdir: str,
50 | callback: CollectCallback,
51 | hosts: Optional[Sequence[str]] = None,
52 | duration_ms: int = 3_000):
53 | """Calls start() followed by stop() after specified duration."""
54 | del hosts # not used.
55 | start(logdir)
56 |
57 | def timer_cb():
58 | stop()
59 | callback(None)
60 |
61 | threading.Timer(duration_ms / 1e3, timer_cb).start()
62 |
63 |
64 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """setup.py for Common Loop Utils.
16 |
17 | Install for development:
18 |
19 |   pip install -e .[test]
20 | """
21 |
22 | import sys
23 |
24 | from setuptools import find_packages
25 | from setuptools import setup
26 |
27 |
28 | if sys.version_info < (3, 10):
29 | sys.exit("Python < 3.10 not supported anymore!")
30 |
31 | tests_require = [
32 | "pytest",
33 | "tensorflow",
34 | "tensorflow_datasets",
35 | "torch>=2.0.0",
36 | ]
37 |
38 | setup(
39 | name="clu",
40 | version="0.0.12",
41 | description="Set of libraries for ML training loops in JAX.",
42 | author="Common Loop Utils Authors",
43 | author_email="no-reply@google.com",
44 | long_description=open("README.md").read(),
45 | long_description_content_type="text/markdown",
46 | url="http://github.com/google/CommonLoopUtils",
47 | license="Apache 2.0",
48 | packages=find_packages(),
49 | include_package_data=True,
50 | install_requires=[
51 | "absl-py",
52 | "etils[epath,epy]",
53 | "flax",
54 | "jax",
55 | "jaxlib",
56 | "ml_collections",
57 | "numpy",
58 | "packaging",
59 | "typing_extensions",
60 | "wrapt",
61 | ],
62 | tests_require=tests_require,
63 | extras_require=dict(test=tests_require),
64 | classifiers=[
65 | "Development Status :: 4 - Beta",
66 | "Intended Audience :: Developers",
67 | "Intended Audience :: Science/Research",
68 | "License :: OSI Approved :: Apache Software License",
69 | "Programming Language :: Python",
70 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
71 | ],
72 | keywords="JAX machine learning",
73 | )
74 |
--------------------------------------------------------------------------------
/clu/platform/interface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Interface work units."""
16 |
17 | import abc
18 | import enum
19 | from typing import Any
20 |
21 |
22 | class ArtifactType(enum.Enum):
23 | # A URL for dashboards, etc.
24 | URL = 1
25 | # File path.
26 | FILE = 2
27 | # Directory path.
28 | DIRECTORY = 3
29 |
30 |
31 | class WorkUnit(abc.ABC):
32 | """A work unit represents a single trial in an experiment.
33 |
34 | Experiments will usually have multiple work units with different
35 | hyperparameters. Each work unit can have multiple jobs (training,
36 | evaluation, etc.). And jobs can have multiple tasks when the training
37 | is distributed across multiple machines.
38 | """
39 |
40 | @property
41 | @abc.abstractmethod
42 | def experiment_id(self):
43 | """ID of the experiment of the work unit."""
44 |
45 | @property
46 | @abc.abstractmethod
47 | def id(self):
48 | """Unique identifier for the work unit."""
49 |
50 | @property
51 | def name(self):
52 | """Returns the name of the work unit as /.
53 |
54 | XID is a ID of the experiment and WID is the number of the work unit
55 | within the experiment.
56 |
57 | Returns:
58 | The work unit name. e.g. 12345/1.
59 | """
60 | return f"{self.experiment_id}/{self.id}"
61 |
62 | @abc.abstractmethod
63 | def set_notes(self, msg: str):
64 | """Sets the notes for this work unit. These are displayed in the UI."""
65 |
66 | @abc.abstractmethod
67 | def set_task_status(self, msg: str):
68 | """Sets the status string for this task."""
69 |
70 | @abc.abstractmethod
71 | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
72 | description: str):
73 | """Creates an artifact entry for the work unit."""
74 |
--------------------------------------------------------------------------------
/clu/metric_writers/multi_writer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for MultiWriter."""
16 |
17 | from unittest import mock
18 |
19 | from clu.metric_writers import interface
20 | from clu.metric_writers import multi_writer
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
25 | class MultiWriterTest(tf.test.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.writers = [
30 | mock.create_autospec(interface.MetricWriter),
31 | mock.create_autospec(interface.MetricWriter)
32 | ]
33 | self.writer = multi_writer.MultiWriter(self.writers)
34 |
35 | def test_write_scalars(self):
36 | self.writer.write_scalars(0, {"a": 3, "b": 0.15})
37 | self.writer.write_scalars(2, {"a": 5, "b": 0.007})
38 | self.writer.flush()
39 | for w in self.writers:
40 | w.write_scalars.assert_has_calls([
41 | mock.call(step=0, scalars={
42 | "a": 3,
43 | "b": 0.15
44 | }),
45 | mock.call(step=2, scalars={
46 | "a": 5,
47 | "b": 0.007
48 | })
49 | ])
50 | w.flush.assert_called()
51 |
52 | def test_write_pointcloud(self):
53 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
54 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
55 | config = {
56 | "material": "PointCloudMaterial",
57 | "size": 0.09,
58 | }
59 | self.writer.write_pointcloud(
60 | step=0,
61 | point_clouds={"pcd": point_clouds},
62 | point_colors={"pcd": point_colors},
63 | configs={"config": config},
64 | )
65 | self.writer.flush()
66 | for w in self.writers:
67 | w.write_pointcloud.assert_called_with(
68 | step=0,
69 | point_clouds={"pcd": point_clouds},
70 | point_colors={"pcd": point_colors},
71 | configs={"config": config},
72 | )
73 | w.flush.assert_called()
74 |
75 |
76 | if __name__ == "__main__":
77 | tf.test.main()
78 |
--------------------------------------------------------------------------------
/clu/metric_writers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Metric writers write ML model outputs during model training and evaluation.
16 |
17 | This module introduces the MetricWriter interface. MetricWriters allow users
18 | to write out metrics about ML models during training and evaluation (e.g. loss,
19 | accuracy).
20 | There is a MetricWriter implementation for each back end (e.g. TensorFlow
21 | summaries) and classes that work on top of other MetricWriters to
22 | write to multiple writers at once or write asynchronously.
23 |
24 | Note: The current interface might not contain write() methods for all possible
25 | data types. We are open to extending the interface to other data types
26 | (e.g. audio).
27 |
28 | Usage:
29 | writer = MyMetricWriterImplementation()
30 | # Before training.
31 | writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
32 | # Start training loop.
33 | for step in range(num_train_steps):
34 | loss = train_step()
35 | if step % 50 == 0:
36 | writer.write_scalars(step, {"loss": loss})
37 | accuracy = evaluate()
38 | writer.write_scalars(step, {"accuracy": accuracy})
39 | # Make sure all values were written.
40 | writer.flush() # or use metric_writers.ensure_flushes() context.
41 | """
42 |
43 | # pylint: disable=unused-import
44 | # pylint: disable=g-importing-member
45 |
46 |
47 | from clu.metric_writers.async_writer import AsyncMultiWriter
48 | from clu.metric_writers.async_writer import AsyncWriter
49 | from clu.metric_writers.async_writer import ensure_flushes
50 | from clu.metric_writers.interface import MetricWriter
51 | from clu.metric_writers.logging_writer import LoggingWriter
52 | from clu.metric_writers.multi_writer import MultiWriter
53 | from clu.metric_writers.summary_writer import SummaryWriter
54 | from clu.metric_writers.utils import create_default_writer
55 | from clu.metric_writers.utils import write_values
56 |
57 | # TODO(b/200953513): Migrate away from logging imports (on module level)
58 | # to logging the actual usage. See b/200953513.
59 |
60 |
--------------------------------------------------------------------------------
/clu/metric_writers/multi_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter that writes to multiple MetricWriters."""
16 |
17 | from collections.abc import Mapping, Sequence
18 | from typing import Any, Optional
19 |
20 | from clu.metric_writers import interface
21 |
22 | Array = interface.Array
23 | Scalar = interface.Scalar
24 |
25 |
26 | class MultiWriter(interface.MetricWriter):
27 | """MetricWriter that writes to multiple writers at once."""
28 |
29 | def __init__(self, writers: Sequence[interface.MetricWriter]):
30 | self._writers = tuple(writers)
31 |
32 | def write_summaries(
33 | self, step: int,
34 | values: Mapping[str, Array],
35 | metadata: Optional[Mapping[str, Any]] = None):
36 | for w in self._writers:
37 | w.write_summaries(step, values, metadata)
38 |
39 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
40 | for w in self._writers:
41 | w.write_scalars(step, scalars)
42 |
43 | def write_images(self, step: int, images: Mapping[str, Array]):
44 | for w in self._writers:
45 | w.write_images(step, images)
46 |
47 | def write_videos(self, step: int, videos: Mapping[str, Array]):
48 | for w in self._writers:
49 | w.write_videos(step, videos)
50 |
51 | def write_audios(
52 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
53 | for w in self._writers:
54 | w.write_audios(step, audios, sample_rate=sample_rate)
55 |
56 | def write_texts(self, step: int, texts: Mapping[str, str]):
57 | for w in self._writers:
58 | w.write_texts(step, texts)
59 |
60 | def write_histograms(self,
61 | step: int,
62 | arrays: Mapping[str, Array],
63 | num_buckets: Optional[Mapping[str, int]] = None):
64 | for w in self._writers:
65 | w.write_histograms(step, arrays, num_buckets)
66 |
67 | def write_pointcloud(
68 | self,
69 | step: int,
70 | point_clouds: Mapping[str, Array],
71 | *,
72 | point_colors: Mapping[str, Array] | None = None,
73 | configs: Mapping[str, str | float | bool | None] | None = None,
74 | ):
75 | for w in self._writers:
76 | w.write_pointcloud(
77 | step, point_clouds, point_colors=point_colors, configs=configs
78 | )
79 |
80 | def write_hparams(self, hparams: Mapping[str, Any]):
81 | for w in self._writers:
82 | w.write_hparams(hparams)
83 |
84 | def flush(self):
85 | for w in self._writers:
86 | w.flush()
87 |
88 | def close(self):
89 | for w in self._writers:
90 | w.close()
91 |
--------------------------------------------------------------------------------
/clu/values.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines available types for use by Metrics when written.
16 |
17 | A Metric should return one of the following types when compute() is called.
18 | """
19 |
20 | import dataclasses
21 | from typing import Any, Union, Protocol, runtime_checkable
22 |
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 | ArrayType = Union[np.ndarray, jnp.ndarray]
27 | ScalarType = Union[int, float, np.number, np.ndarray, jnp.ndarray]
28 |
29 |
30 | @runtime_checkable
31 | class Value(Protocol):
32 | """Class defining available metric computation return values.
33 |
34 | Types mirror those available in MetricWriter. See
35 | clu/metric_writers/interface.py
36 | """
37 | value: Any
38 |
39 |
40 | @dataclasses.dataclass
41 | class Summary(Value):
42 | value: ArrayType
43 | metadata: Any
44 |
45 |
46 | @dataclasses.dataclass
47 | class Scalar(Value):
48 | value: ScalarType
49 |
50 |
51 | @dataclasses.dataclass
52 | class Image(Value):
53 | """Image type.
54 |
55 | Mapping from image key to images. Images should have the shape [N, H, W, C] or
56 | [H, W, C], where H is the height, W is the width and C the
57 | number of channels (1 or 3). N is the number of images that will be
58 | written. Image dimensions can differ between different image keys but
59 | not between different steps for the same image key.
60 | """
61 | value: ArrayType
62 |
63 |
64 | @dataclasses.dataclass
65 | class Audio(Value):
66 | """Audio type.
67 |
68 | Mapping from audio key to audios. Audios should have the shape [N, T, C],
69 | where T is the time length and C the number of channels (1 -
70 | mono, 2 - stereo, >= 3 - surround; not all writers support any number of
71 | channels). N is the number of audios that will be written. Audio
72 | dimensions can differ between different audio keys but not between
73 | different steps for the same audio key. Values should be floating-point
74 | values in [-1, +1].
75 | """
76 | value: ArrayType
77 | sample_rate: int
78 |
79 |
80 | @dataclasses.dataclass
81 | class Text(Value):
82 | value: str
83 |
84 |
85 | @dataclasses.dataclass
86 | class Histogram(Value):
87 | # value must be an array of counts (integers)
88 | value: ArrayType
89 | num_buckets: int
90 |
91 |
92 | @dataclasses.dataclass
93 | class HyperParam(Value):
94 | """The name of the hyperparameter should be handled outside this class.
95 |
96 | Value should correspond to a single hyperparameter, while a Mapping[str,
97 | HyperParam] (name to HyperParam) is maintained independently.
98 | """
99 | value: Any
100 |
--------------------------------------------------------------------------------
/clu/internal/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Small utilities by CLU libraries."""
16 |
17 | import contextlib
18 | import sys
19 | import time
20 | from typing import Any, List, Mapping, Tuple, Union
21 |
22 | from absl import logging
23 |
24 | import jax.numpy as jnp
25 | import numpy as np
26 | import wrapt
27 |
28 |
29 | @contextlib.contextmanager
30 | def log_activity(activity_name: str):
31 | """Logs `activity_name` and timing information (or exception)."""
32 | t0 = time.time()
33 | logging.info("%s ...", activity_name)
34 | try:
35 | yield
36 | finally:
37 | dt = time.time() - t0
38 | exc, *_ = sys.exc_info()
39 | if exc is not None:
40 | logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt,
41 | exc.__name__)
42 | else:
43 | logging.info("%s finished after %.2fs.", activity_name, dt)
44 |
45 |
46 |
47 | def logged_with(activity_name: str):
48 | """Returns a decorator wrapping a function with `log_activity()`."""
49 | @wrapt.decorator
50 | def decorator(wrapped, instance, args, kwargs):
51 | del instance # not used
52 | with log_activity(activity_name):
53 | return wrapped(*args, **kwargs)
54 | return decorator
55 |
56 |
57 | def check_param(value, *, ndim=None, dtype=jnp.float32):
58 | """Raises a `ValueError` if `value` does not match ndim/dtype.
59 |
60 | Args:
61 | value: Value to be tested.
62 | ndim: Expected dimensions.
63 | dtype: Expected dtype.
64 |
65 | Raises:
66 | A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
67 | is not an instance of `jnp.ndarray`.
68 | """
69 | if not isinstance(value, (np.ndarray, jnp.ndarray)):
70 | raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
71 | if ndim is not None and value.ndim != ndim:
72 | raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
73 | if dtype is not None and value.dtype != dtype:
74 | raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}")
75 |
76 |
77 | def flatten_dict(
78 | d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
79 | ) -> List[Tuple[str, Union[int, float, str]]]:
80 | """Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().
81 |
82 | Args:
83 |     d: A dict-like object that has an `.items()` method.
84 | prefix: Prefix to add to keys in `d`.
85 |
86 | Returns:
87 | Sequence of (k, v) pairs where k is the flattened key with individual
88 | subkeys separated by dots. `None` values are replaced by the empty string.
89 | """
90 | ret = []
91 | for k, v in d.items():
92 | # Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
93 | if isinstance(v, Mapping) or hasattr(v, "items"):
94 | ret += flatten_dict(v, prefix + (k,))
95 | elif isinstance(v, (list, tuple)):
96 | ret += flatten_dict({str(idx): value for idx, value in enumerate(v)},
97 | prefix + (k,))
98 | else:
99 | ret.append((".".join(prefix + (k,)), v if v is not None else ""))
100 | return ret
101 |
--------------------------------------------------------------------------------
/clu/metric_writers/torch_tensorboard_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter for Pytorch summary files.
16 |
17 | Use this writer for the Pytorch-based code.
18 |
19 | """
20 |
21 | from collections.abc import Mapping
22 | from typing import Any, Optional
23 | from absl import logging
24 |
25 | from clu.metric_writers import interface
26 | from torch.utils import tensorboard
27 |
28 | Array = interface.Array
29 | Scalar = interface.Scalar
30 |
31 |
32 | class TorchTensorboardWriter(interface.MetricWriter):
33 | """MetricWriter that writes Pytorch summary files."""
34 |
35 | def __init__(self, logdir: str):
36 | super().__init__()
37 | self._writer = tensorboard.SummaryWriter(log_dir=logdir)
38 |
39 |
40 | def write_summaries(
41 | self, step: int,
42 | values: Mapping[str, Array],
43 | metadata: Optional[Mapping[str, Any]] = None):
44 | logging.log_first_n(
45 | logging.WARNING,
46 | "TorchTensorboardWriter does not support writing raw summaries.", 1)
47 |
48 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
49 | for key, value in scalars.items():
50 | self._writer.add_scalar(key, value, global_step=step)
51 |
52 | def write_images(self, step: int, images: Mapping[str, Array]):
53 | for key, value in images.items():
54 | self._writer.add_image(key, value, global_step=step, dataformats="HWC")
55 |
56 | def write_videos(self, step: int, videos: Mapping[str, Array]):
57 | logging.log_first_n(
58 | logging.WARNING,
59 | "TorchTensorBoardWriter does not support writing videos.", 1)
60 |
61 | def write_audios(
62 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
63 | for key, value in audios.items():
64 | self._writer.add_audio(
65 | key, value, global_step=step, sample_rate=sample_rate)
66 |
67 | def write_texts(self, step: int, texts: Mapping[str, str]):
68 | raise NotImplementedError(
69 | "TorchTensorBoardWriter does not support writing texts."
70 | )
71 |
72 | def write_histograms(self,
73 | step: int,
74 | arrays: Mapping[str, Array],
75 | num_buckets: Optional[Mapping[str, int]] = None):
76 | for tag, values in arrays.items():
77 | bins = None if num_buckets is None else num_buckets.get(tag)
78 | self._writer.add_histogram(
79 | tag, values, global_step=step, bins="auto", max_bins=bins)
80 |
81 | def write_pointcloud(
82 | self,
83 | step: int,
84 | point_clouds: Mapping[str, Array],
85 | *,
86 | point_colors: Mapping[str, Array] | None = None,
87 | configs: Mapping[str, str | float | bool | None] | None = None,
88 | ):
89 | logging.log_first_n(
90 | logging.WARNING,
91 | "TorchTensorBoardWriter does not support writing point clouds.",
92 | 1,
93 | )
94 |
95 | def write_hparams(self, hparams: Mapping[str, Any]):
96 | self._writer.add_hparams(hparams, {})
97 |
98 | def flush(self):
99 | self._writer.flush()
100 |
101 | def close(self):
102 | self._writer.close()
103 |
--------------------------------------------------------------------------------
/clu/metric_writers/torch_tensorboard_writer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for TorchTensorboardWriter."""
16 |
17 | import collections
18 | import os
19 | from typing import Any, Dict
20 |
21 | from clu.metric_writers import torch_tensorboard_writer
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
26 | def _load_scalars_data(logdir: str):
27 | """Loads scalar summaries from events in a logdir."""
28 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
29 | data = collections.defaultdict(dict)
30 | for path in paths:
31 | for event in tf.compat.v1.train.summary_iterator(path):
32 | for value in event.summary.value:
33 | data[event.step][value.tag] = value.simple_value
34 |
35 | return data
36 |
37 |
38 | def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
39 | """Loads histograms summaries from events in a logdir.
40 |
41 | Args:
42 | logdir: a directory to find logs
43 |
44 | Returns:
45 |     The loaded histograms, keyed as step -> tag -> histo.
46 | """
47 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
48 | data = {}
49 | for path in paths:
50 | for event in tf.compat.v1.train.summary_iterator(path):
51 | if event.step not in data:
52 | data[event.step] = {}
53 | step_data = {}
54 | for value in event.summary.value:
55 | print(" value:", value)
56 | step_data[value.tag] = value.histo
57 | data[event.step].update(step_data)
58 |
59 | return data
60 |
61 |
62 | class TorchTensorboardWriterTest(tf.test.TestCase):
63 |
64 | def setUp(self):
65 | super().setUp()
66 | self.logdir = self.get_temp_dir()
67 | self.writer = torch_tensorboard_writer.TorchTensorboardWriter(self.logdir)
68 |
69 | def test_write_scalar(self):
70 | self.writer.write_scalars(11, {"a": 0.6, "b": 15})
71 | self.writer.write_scalars(20, {"a": 0.8, "b": 12})
72 | self.writer.flush()
73 | data = _load_scalars_data(self.logdir)
74 | self.assertAllClose(data[11], {"a": 0.6, "b": 15})
75 | self.assertAllClose(data[20], {"a": 0.8, "b": 12})
76 |
77 | def test_write_histograms(self):
78 | self.writer.write_histograms(
79 | 0, {
80 | "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
81 | "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
82 | }, num_buckets={"a": 2, "b": 2})
83 | self.writer.write_histograms(
84 | 2, {
85 | "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
86 | "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
87 | }, num_buckets={"a": 2, "b": 2})
88 | self.writer.flush()
89 | data = _load_histograms_data(self.logdir)
90 | self.assertNear(data[0]["a"].min, 0.1, 0.001)
91 | self.assertNear(data[0]["a"].max, 0.7, 0.001)
92 | self.assertNear(data[0]["b"].min, -0.1, 0.001)
93 | self.assertNear(data[0]["b"].max, 0.4, 0.001)
94 | self.assertNear(data[2]["a"].min, -0.1, 0.001)
95 | self.assertNear(data[2]["a"].max, 0.5, 0.001)
96 | self.assertNear(data[2]["b"].min, 0.0, 0.001)
97 | self.assertNear(data[2]["b"].max, 0.7, 0.001)
98 |
99 |
100 | if __name__ == "__main__":
101 | tf.test.main()
102 |
--------------------------------------------------------------------------------
/clu/internal/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from unittest import mock
16 |
17 | from absl.testing import absltest
18 | from clu.internal import utils
19 | import jax.numpy as jnp
20 | import ml_collections
21 |
22 |
23 | class TestError(BaseException):
24 | __test__ = False
25 | pass
26 |
27 |
28 | class HelpersTest(absltest.TestCase):
29 |
30 | def test_log_activity(
31 | self,
32 | ):
33 | with self.assertLogs() as logs:
34 | with utils.log_activity("test_activity"):
35 | pass
36 | self.assertLen(logs.output, 2)
37 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
38 | self.assertRegex(logs.output[1],
39 | r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
40 |
41 | def test_log_activity_fails(
42 | self,
43 | ):
44 | with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
45 | with self.assertLogs() as logs:
46 | with utils.log_activity("test_activity"):
47 | raise TestError()
48 | self.assertLen(logs.output, 2)
49 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
50 | self.assertRegex(logs.output[1],
51 | r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
52 |
53 | def test_logged_with(self):
54 |
55 | @utils.logged_with("test_activity")
56 | def test():
57 | pass
58 |
59 | with self.assertLogs() as logs:
60 | test()
61 | self.assertLen(logs.output, 2)
62 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
63 | self.assertRegex(logs.output[1],
64 | r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
65 |
66 | def test_logged_with_fails(self):
67 |
68 | @utils.logged_with("test_activity")
69 | def test():
70 | raise TestError()
71 |
72 | with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
73 | with self.assertLogs() as logs:
74 | test()
75 | self.assertLen(logs.output, 2)
76 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
77 | self.assertRegex(logs.output[1],
78 | r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
79 |
80 | def test_check_param(self):
81 | a = jnp.array(0.)
82 | with self.assertRaisesRegex(ValueError, r"^Expected np.array or jnp.array"):
83 | utils.check_param(None, ndim=1)
84 | with self.assertRaisesRegex(ValueError, r"^Expected ndim"):
85 | utils.check_param(a, ndim=1)
86 | with self.assertRaisesRegex(ValueError, r"^Expected dtype"):
87 | utils.check_param(a, ndim=0, dtype=jnp.int32)
88 | utils.check_param(a, ndim=0) # should work
89 | utils.check_param(a, ndim=0, dtype=jnp.float32) # should also work
90 |
91 | def test_flatten_dict(self):
92 | self.assertEqual(
93 | utils.flatten_dict(
94 | ml_collections.ConfigDict({
95 | "x": 1,
96 | "y": None,
97 | "z": ml_collections.ConfigDict({
98 | "a": "bc",
99 | })
100 | })), [("x", 1), ("y", ""), ("z.a", "bc")])
101 |
102 |
103 | if __name__ == "__main__":
104 | absltest.main()
105 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## v0.0.1-alpha.1
4 |
5 | Initial PyPI release
6 |
7 | Current list of modules:
8 |
9 | - `clu.checkpoint`
10 | - `clu.deterministic_training`
11 | - `clu.metric_writers`
12 | - `clu.periodic_actions`
13 | - `clu.platform`
14 | - `clu.profiler`
15 |
16 | ## v0.0.1-alpha.2
17 |
18 | - Adds `metrics` module and some minor changes.
19 |
20 | ## v0.0.1a3
21 |
22 | - Added `metric_writers.TorchTensorboardWriter`
23 |
24 | ## v0.0.2
25 |
26 | - Added preprocess_spec.
27 | - Improvements to periodic_actions.
28 |
29 | ## v0.0.3
30 |
31 | - `metric_writers`: Lets `SummaryWriter` write nested dictionaries.
32 | - `internal`: Adds `async.Pool`.
33 | - `preprocess_spec`: Support nested dictionaries.
34 | - `profile`: Use JAX profiler APIs instead of TF profiler APIs.
35 |
36 | ## v0.0.4
37 |
38 | `deterministic_data`
39 |
40 | - Support non-positive input value for pad_up_to_batches.
41 | - Support padding dataset when data dimension is unknown.
42 | - Support TFDS specs in get_read_instruction_for_host.
43 | - Allow changing drop_remainder for batching.
44 | - Add RemainderOptions in deterministic_data.
45 |
46 | `metric_writers`
47 |
48 | - Support multiple writers in metric_writers.ensure_flushes.
49 |
50 | `metrics`
51 |
52 | - Makes internal.flatten_dict() work with ConfigDicts.
53 | - Forwards mask model output to metrics created via `Metric.from_output()`.
54 | - Forwards mask model output to metrics created via `Metric.from_fun()`.
55 | - Added `Collections.unreplicate()`, `Collections.create()`.
56 |
57 | `periodic_actions`
58 |
59 | - Formats long time strings in '{days}d{hours}h{mins}m' format.
60 |
61 | `preprocess_spec`
62 |
63 | - Make feature description of features in PreprocessFn more compact.
64 | - Better type check in `preprocess_spec.get_all_ops()`.
65 |
66 | Documentation:
67 |
68 | - Added `clu_synopsis.ipynb` Colab
69 |
70 | ## v0.0.5
71 |
72 | - Log error instead of failing when `profiler.start()` raises an exception.
73 | - Makes `periodic_actions.ProgressUpdate` show total number of steps.
74 | - Makes `AsyncWriter` non-blocking wrt JAX async computations.
75 | - Adds `clu_synopsis.ipynb` Colab as initial documentation.
76 | - Restore Checkpoint without providing the state
77 | - Makes `PreprocessFn` addable.
78 | - Allow n-dimensional arrays (and masks) to be passed to Metrics.Average().
79 | - Support slicing `PreprocessFn`.
80 |
81 | ## v0.0.6
82 |
83 | - Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`.
84 |
85 | This will be the last release supporting Python 3.6.
86 |
87 | ## v0.0.7
88 |
89 | - Moves `clu.internal.asynclib` to `clu.asynclib`.
90 | - Adds methods for writing raw tensors and audio to `MetricWriter`.
91 | - Adds `clu.values` to annotate arrays with a modality.
92 | - Adds `clu.data.DatasetIterator` - a generic interface between input
93 | pipelines and training loops.
94 | - Fixes various issues with `clu.metrics`.
95 |
96 | This will be the last release supporting Python 3.7.
97 |
98 | ## v0.0.9
99 |
100 | - Fix pytype failures related to teaching pytype about NumPy scalar types.
101 | - Fix a couple of docstring typos.
102 | - Updates README and clu_synopsis.ipynb
103 |
104 | Last release before dropping support for Python 3.8 and 3.9
105 |
106 | ## v0.0.10
107 |
108 | - `clu.parameter_overview` now supports JAX global arrays.
109 | - Various small fixes in `clu.metrics` module.
110 | - Removed some tensorflow dependencies.
111 |
112 | ## v0.0.11
113 |
114 | - Removes numpy version pin
115 | - Adds sharding annotations, dtype, total bytes to `parameter_overview`
116 | - Makes `clu.metrics.Std` support same shapes as `clu.metrics.Average`
117 |
118 | ## v0.0.12
119 |
120 | - Switch from `jax.tree_map` (deprecated since JAX 0.4.26) to
121 | `jax.tree_util.tree_map`.
122 | - Improvements to parameter overview.
123 |
--------------------------------------------------------------------------------
/clu/asynclib_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for clu.asynclib."""
16 |
17 | from unittest import mock
18 |
19 | from absl.testing import absltest
20 | from clu import asynclib
21 |
22 |
23 | class AsyncWriterTest(absltest.TestCase):
24 |
25 | def test_async_execution(self):
26 | pool = asynclib.Pool()
27 | counter = 0
28 |
29 | @pool
30 | def fn(counter_increment, return_value):
31 | nonlocal counter
32 | counter += counter_increment
33 | return return_value
34 |
35 | future = fn(1, return_value=2)
36 | self.assertEqual(counter, 1)
37 | self.assertEqual(future.result(), 2)
38 |
39 | def test_reraise(self):
40 | pool = asynclib.Pool()
41 |
42 | @pool
43 | def error():
44 | raise ValueError("test")
45 |
46 | error()
47 | self.assertTrue(pool.has_errors)
48 | with self.assertRaisesRegex(asynclib.AsyncError, "test"):
49 | pool.join()
50 | self.assertFalse(pool.has_errors)
51 |
52 | @pool
53 | def noop():
54 | ...
55 |
56 | error()
57 | self.assertTrue(pool.has_errors)
58 | with self.assertRaisesRegex(asynclib.AsyncError, "test"):
59 | noop()
60 | self.assertFalse(pool.has_errors)
61 |
62 | pool.join()
63 |
64 | @mock.patch("concurrent.futures.ThreadPoolExecutor")
65 | def test_queue_length(self, executor_mock):
66 | pool_mock = mock.Mock()
67 | in_flight = []
68 |
69 | def execute_one():
70 | in_flight.pop(0)()
71 |
72 | def submit(fn, *args, **kwargs):
73 | in_flight.append(lambda: fn(*args, **kwargs))
74 |
75 | pool_mock.submit = submit
76 | executor_mock.return_value = pool_mock
77 |
78 | pool = asynclib.Pool()
79 |
80 | @pool
81 | def noop():
82 | ...
83 |
84 | self.assertEqual(pool.queue_length, 0)
85 | noop()
86 | self.assertEqual(pool.queue_length, 1)
87 | noop()
88 | self.assertEqual(pool.queue_length, 2)
89 | execute_one()
90 | self.assertEqual(pool.queue_length, 1)
91 | execute_one()
92 | self.assertEqual(pool.queue_length, 0)
93 |
94 | @mock.patch("concurrent.futures.ThreadPoolExecutor")
95 | def test_flush(self, executor_mock):
96 | pool_mock = mock.Mock()
97 | pool_mock._in_flight = None
98 |
99 | def execute_one():
100 | pool_mock._in_flight.pop(0)()
101 |
102 | def submit(fn, *args, **kwargs):
103 | pool_mock._in_flight.append(lambda: fn(*args, **kwargs))
104 |
105 | def create_pool(max_workers, thread_name_prefix):
106 | del max_workers
107 | del thread_name_prefix
108 | pool_mock._in_flight = []
109 | return pool_mock
110 |
111 | def shutdown(wait=False):
112 | if wait:
113 | while pool_mock._in_flight:
114 | execute_one()
115 | pool_mock._in_flight = None
116 |
117 | pool_mock.submit = submit
118 | executor_mock.side_effect = create_pool
119 | pool_mock.shutdown.side_effect = shutdown
120 |
121 | pool = asynclib.Pool()
122 |
123 | @pool
124 | def noop():
125 | ...
126 |
127 | self.assertEqual(pool.queue_length, 0)
128 | noop()
129 | self.assertEqual(pool.queue_length, 1)
130 | noop()
131 | pool.join()
132 | self.assertEqual(pool.queue_length, 0)
133 | noop()
134 | self.assertEqual(pool.queue_length, 1)
135 |
136 |
137 | if __name__ == "__main__":
138 | absltest.main()
139 |
--------------------------------------------------------------------------------
/clu/data/dataset_iterator_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for dataset_iterator."""
16 | import itertools
17 | import pathlib
18 | import tempfile
19 |
20 | from absl.testing import parameterized
21 | from clu.data import dataset_iterator
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 | INDEX = "_index"
26 |
27 |
28 | class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):
29 |
30 | def _create_iterator(self, start_index: int, checkpoint: bool = True):
31 | """Create an iterator over some prime numbers with index."""
32 | primes = tf.constant([2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
33 | ds = tf.data.Dataset.range(start_index, 10)
34 | ds = ds.map(lambda i: {INDEX: i, "prime": primes[i]})
35 | # Remove index 1 and 3.
36 | ds = ds.filter(lambda x: tf.logical_and(x["prime"] != 3, x["prime"] != 7))
37 | ds = ds.batch(2, drop_remainder=True)
38 | return dataset_iterator.TfDatasetIterator(ds, checkpoint=checkpoint)
39 |
40 | def test_tf_iterator(self):
41 | it = self._create_iterator(0)
42 | self.assertEqual(
43 | it.element_spec, {
44 | INDEX: dataset_iterator.ArraySpec(np.int64, (2,)),
45 | "prime": dataset_iterator.ArraySpec(np.int32, (2,))
46 | })
47 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
48 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
49 | it.reset()
50 | # Iterator starts from the beginning.
51 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
52 |
53 | def test_tf_iterator_save_and_load(self):
54 | it = self._create_iterator(0)
55 | next(it)
56 | next(it)
57 | next(it)
58 | work_dir = pathlib.Path(tempfile.mkdtemp())
59 | filename = work_dir / "ckpt"
60 | it.save(filename)
61 | self.assertTrue((work_dir / "ckpt.index").exists())
62 |
63 | it = self._create_iterator(0)
64 | # Iterator is at the beginning (batch 1).
65 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
66 | it.load(filename)
67 | # Iterator is at the end (batch 4).
68 | self.assertEqual(next(it), {INDEX: [8, 9], "prime": [23, 29]})
69 |
70 | def test_tf_iterator_save_and_load_no_checkpoint(self):
71 | it = self._create_iterator(0, checkpoint=False)
72 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
73 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
74 | work_dir = pathlib.Path(tempfile.mkdtemp())
75 | filename = work_dir / "ckpt"
76 | it.save(filename) # Should be a no-op and not create a checkpoint.
77 | self.assertFalse((work_dir / "ckpt.index").exists())
78 |
79 | it = self._create_iterator(0, checkpoint=False)
80 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
81 | it.restore(filename) # Should be a no-op, iterator just continues.
82 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
83 |
84 | def test_peekable_dataset_iterator(self):
85 | it = self._create_iterator(0)
86 | it = dataset_iterator.PeekableDatasetIterator(it)
87 | self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
88 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
89 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
90 |
91 | @parameterized.parameters(itertools.product([True, False], [True, False]))
92 | def test_peekable_dataset_iterator_async(self, wait: bool, peek_first: bool):
93 | it = self._create_iterator(0)
94 | it = dataset_iterator.PeekableDatasetIterator(it)
95 | future = it.peek_async()
96 | self.assertIsNone(it._peek)
97 | if wait:
98 | future.result()
99 | self.assertIsNotNone(it._peek)
100 | if peek_first:
101 | self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
102 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
103 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
104 |
105 |
106 | if __name__ == "__main__":
107 | tf.test.main()
108 |
--------------------------------------------------------------------------------
/clu/metric_writers/tf/summary_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter for writing to TF summary files.
16 |
17 | Only works in eager mode. Does not work for PyTorch code; please use
18 | TorchTensorboardWriter instead.
19 | """
20 |
21 | from collections.abc import Mapping
22 | from typing import Any, Optional
23 |
24 | from absl import logging
25 |
26 | from clu.internal import utils
27 | from clu.metric_writers import interface
28 | from etils import epy
29 | import tensorflow as tf
30 |
31 | with epy.lazy_imports():
32 | # pylint: disable=g-import-not-at-top
33 | from tensorboard.plugins.hparams import api as hparams_api
34 | from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
35 | # pylint: enable=g-import-not-at-top
36 |
37 |
38 | Array = interface.Array
39 | Scalar = interface.Scalar
40 |
41 |
42 | class SummaryWriter(interface.MetricWriter):
43 | """MetricWriter that writes TF summary files."""
44 |
45 | def __init__(self, logdir: str):
46 | super().__init__()
47 | self._summary_writer = tf.summary.create_file_writer(logdir)
48 |
49 |
50 | def write_summaries(
51 | self,
52 | step: int,
53 | values: Mapping[str, Array],
54 | metadata: Optional[Mapping[str, Any]] = None,
55 | ):
56 | with self._summary_writer.as_default():
57 | for key, value in values.items():
58 | md = metadata.get(key) if metadata is not None else None
59 | tf.summary.write(key, value, step=step, metadata=md)
60 |
61 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
62 | with self._summary_writer.as_default():
63 | for key, value in scalars.items():
64 | tf.summary.scalar(key, value, step=step)
65 |
66 | def write_images(self, step: int, images: Mapping[str, Array]):
67 | with self._summary_writer.as_default():
68 | for key, value in images.items():
69 | if len(value.shape) == 3:
70 | value = value[None]
71 | tf.summary.image(key, value, step=step, max_outputs=value.shape[0])
72 |
73 | def write_videos(self, step: int, videos: Mapping[str, Array]):
74 | logging.log_first_n(
75 | logging.WARNING,
76 | "SummaryWriter does not support writing videos.", 1)
77 |
78 | def write_audios(
79 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
80 | with self._summary_writer.as_default():
81 | for key, value in audios.items():
82 | tf.summary.audio(key, value, sample_rate=sample_rate, step=step,
83 | max_outputs=value.shape[0])
84 |
85 | def write_texts(self, step: int, texts: Mapping[str, str]):
86 | with self._summary_writer.as_default():
87 | for key, value in texts.items():
88 | tf.summary.text(key, value, step=step)
89 |
90 | def write_histograms(
91 | self,
92 | step: int,
93 | arrays: Mapping[str, Array],
94 | num_buckets: Optional[Mapping[str, int]] = None,
95 | ):
96 | with self._summary_writer.as_default():
97 | for key, value in arrays.items():
98 | buckets = None if num_buckets is None else num_buckets.get(key)
99 | tf.summary.histogram(key, value, step=step, buckets=buckets)
100 |
101 | def write_pointcloud(
102 | self,
103 | step: int,
104 | point_clouds: Mapping[str, Array],
105 | *,
106 | point_colors: Mapping[str, Array] | None = None,
107 | configs: Mapping[str, str | float | bool | None] | None = None,
108 | ):
109 | with self._summary_writer.as_default():
110 | for key, vertices in point_clouds.items():
111 | colors = None if point_colors is None else point_colors.get(key)
112 | config = None if configs is None else configs.get(key)
113 | mesh_summary.mesh(
114 | key,
115 | vertices=vertices,
116 | colors=colors,
117 | step=step,
118 | config_dict=config,
119 | )
120 |
121 | def write_hparams(self, hparams: Mapping[str, Any]):
122 | with self._summary_writer.as_default():
123 | hparams_api.hparams(dict(utils.flatten_dict(hparams)))
124 |
125 | def flush(self):
126 | self._summary_writer.flush()
127 |
128 | def close(self):
129 | self._summary_writer.close()
130 |
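131 |
132 | # --- Usage sketch (editor's addition, not part of the library). A minimal,
133 | # hedged example of the API defined above; it assumes TensorFlow and the
134 | # TensorBoard plugins are installed, a temporary directory is writable, and
135 | # the metric names ("loss", etc.) are arbitrary illustrations.
136 | if __name__ == "__main__":
137 |   import tempfile
138 |
139 |   _logdir = tempfile.mkdtemp()
140 |   _writer = SummaryWriter(_logdir)
141 |   _writer.write_hparams({"learning_rate": 0.01, "batch_size": 64})
142 |   _writer.write_scalars(0, {"loss": 2.3})
143 |   _writer.write_scalars(1, {"loss": 1.7})
144 |   _writer.flush()
145 |   _writer.close()
146 |   print("Wrote TF summary files to", _logdir)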
--------------------------------------------------------------------------------
/clu/metric_writers/logging_writer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the LoggingWriter."""
16 |
17 | from clu.metric_writers import logging_writer
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 |
22 | class LoggingWriterTest(tf.test.TestCase):
23 |
24 | def setUp(self):
25 | super().setUp()
26 | self.writer = logging_writer.LoggingWriter()
27 |
28 | def test_write_scalars(self):
29 | with self.assertLogs(level="INFO") as logs:
30 | self.writer.write_scalars(0, {"a": 3, "b": 0.15})
31 | self.writer.write_scalars(2, {"a": 0.0000005, "b": 0.007})
32 | self.assertEqual(
33 | logs.output,
34 | ["INFO:absl:[0] a=3, b=0.15", "INFO:absl:[2] a=5e-07, b=0.007"])
35 |
36 | def test_write_images(self):
37 | images = np.zeros((2, 28, 28, 3))
38 | with self.assertLogs(level="INFO") as logs:
39 | self.writer.write_images(4, {"input_images": images})
40 | self.assertEqual(
41 | logs.output,
42 | ["INFO:absl:[4] Got images: {'input_images': (2, 28, 28, 3)}."])
43 |
44 | def test_write_videos(self):
45 | videos = np.zeros((2, 4, 28, 28, 3))
46 | with self.assertLogs(level="INFO") as logs:
47 | self.writer.write_videos(4, {"input_videos": videos})
48 | self.assertEqual(
49 | logs.output,
50 | ["INFO:absl:[4] Got videos: {'input_videos': (2, 4, 28, 28, 3)}."])
51 |
52 | def test_write_texts(self):
53 | with self.assertLogs(level="INFO") as logs:
54 | self.writer.write_texts(4, {"samples": "bla"})
55 | self.assertEqual(
56 | logs.output,
57 | ["INFO:absl:[4] Got texts: {'samples': 'bla'}."])
58 |
59 | def test_write_histogram(self):
60 | with self.assertLogs(level="INFO") as logs:
61 | self.writer.write_histograms(
62 | step=4,
63 | arrays={
64 | "a": np.asarray([-0.1, 0.1, 0.3]),
65 | "b": np.arange(31),
66 | "c": np.asarray([0.1, 0.1, 0.1, 0.1, 0.1]),
67 | },
68 | num_buckets={
69 | "a": 2,
70 | "c": 1
71 | })
72 | # Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by
73 | # default. Last bucket gets 2 values.
74 | expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] +
75 | ["[29, 30]: 2"])
76 | self.assertEqual(logs.output, [
77 | "INFO:absl:[4] Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
78 | f"INFO:absl:[4] Histogram for 'b' = {{{expected_histo_b}}}",
79 | "WARNING:absl:num_buckets was automatically changed from 1 to 2",
80 | "INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
81 | ])
82 |
83 | def test_write_pointcloud(self):
84 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
85 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
86 | config = {
87 | "material": "PointCloudMaterial",
88 | "size": 0.09,
89 | }
90 | with self.assertLogs(level="INFO") as logs:
91 | self.writer.write_pointcloud(
92 | step=4,
93 | point_clouds={"pcd": point_clouds},
94 | point_colors={"pcd": point_colors},
95 | configs={"configs": config},
96 | )
97 | self.assertEqual(
98 | logs.output,
99 | [
100 | "INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
101 | " point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
102 | " {'material': 'PointCloudMaterial', 'size': 0.09}}."
103 | ],
104 | )
105 |
106 | def test_write_hparams(self):
107 | with self.assertLogs(level="INFO") as logs:
108 | self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
109 | self.assertEqual(logs.output, [
110 | "INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}"
111 | ])
112 |
113 | def test_collection(self):
114 | writer = logging_writer.LoggingWriter(collection="train")
115 | with self.assertLogs(level="INFO") as logs:
116 | writer.write_scalars(0, {"a": 3, "b": 0.15})
117 | writer.write_images(4, {"input_images": np.zeros((2, 28, 28, 3))})
118 | writer.write_texts(4, {"samples": "bla"})
119 | writer.write_histograms(
120 | step=4,
121 | arrays={
122 | "a": np.asarray([-0.1, 0.1, 0.3]),
123 | },
124 | num_buckets={
125 | "a": 2,
126 | })
127 | writer.write_hparams({"learning_rate": 0.1})
128 |
129 | self.assertEqual(logs.output, [
130 | "INFO:absl:[0] collection=train a=3, b=0.15",
131 | "INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.",
132 | "INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.",
133 | "INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
134 | "INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}",
135 | ])
136 |
137 |
138 | if __name__ == "__main__":
139 | tf.test.main()
140 |
--------------------------------------------------------------------------------
/clu/metric_writers/async_writer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for AsyncWriter."""
16 |
17 | import time
18 | from unittest import mock
19 |
20 | from clu import asynclib
21 | from clu.metric_writers import async_writer
22 | from clu.metric_writers import interface
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
27 | class AsyncWriterTest(tf.test.TestCase):
28 |
29 | def setUp(self):
30 | super().setUp()
31 | self.sync_writer = mock.create_autospec(interface.MetricWriter)
32 | self.writer = async_writer.AsyncWriter(self.sync_writer)
33 |
34 | def test_write_summaries_async(self):
35 | self.writer.write_summaries(
36 | 11,
37 | {"a": np.eye(3, dtype=np.uint8),
38 | "b": np.eye(2, dtype=np.float32)},
39 | {"a": np.ones((2, 3)).tobytes()})
40 | self.writer.flush()
41 | self.sync_writer.write_summaries.assert_called_with(
42 | step=11,
43 | values={"a": mock.ANY, "b": mock.ANY},
44 | metadata={"a": mock.ANY})
45 |
46 | def test_write_scalars_async(self):
47 | self.writer.write_scalars(0, {"a": 3, "b": 0.15})
48 | self.writer.write_scalars(2, {"a": 5, "b": 0.007})
49 | self.writer.flush()
50 | self.sync_writer.write_scalars.assert_has_calls([
51 | mock.call(step=0, scalars={
52 | "a": 3,
53 | "b": 0.15
54 | }),
55 | mock.call(step=2, scalars={
56 | "a": 5,
57 | "b": 0.007
58 | })
59 | ])
60 |
61 | def test_write_images(self):
62 | images = np.zeros((2, 28, 28, 3))
63 | self.writer.write_images(4, {"input_images": images})
64 | self.writer.flush()
65 | self.sync_writer.write_images.assert_called_with(4,
66 | {"input_images": mock.ANY})
67 |
68 | def test_write_videos(self):
69 | videos = np.zeros((2, 4, 28, 28, 3))
70 | self.writer.write_videos(4, {"input_videos": videos})
71 | self.writer.flush()
72 | self.sync_writer.write_videos.assert_called_with(4,
73 | {"input_videos": mock.ANY})
74 |
75 | def test_write_pointcloud(self):
76 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
77 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
78 | config = {
79 | "material": "PointCloudMaterial",
80 | "size": 0.09,
81 | }
82 | self.writer.write_pointcloud(
83 | step=0,
84 | point_clouds={"pcd": point_clouds},
85 | point_colors={"pcd": point_colors},
86 | configs={"config": config},
87 | )
88 | self.writer.flush()
89 | self.sync_writer.write_pointcloud.assert_called_with(
90 | step=0,
91 | point_clouds={"pcd": mock.ANY},
92 | point_colors={"pcd": mock.ANY},
93 | configs={"config": mock.ANY},
94 | )
95 |
96 | def test_write_texts(self):
97 | self.writer.write_texts(4, {"samples": "bla"})
98 | self.writer.flush()
99 | self.sync_writer.write_texts.assert_called_with(4, {"samples": "bla"})
100 |
101 | def test_ensure_flushes(self):
102 | with async_writer.ensure_flushes(self.writer) as writer:
103 | writer.write_scalars(0, {"a": 3, "b": 0.15})
104 | writer.write_scalars(2, {"a": 5, "b": 0.007})
105 | self.sync_writer.write_scalars.assert_has_calls([
106 | mock.call(step=0, scalars={
107 | "a": 3,
108 | "b": 0.15
109 | }),
110 | mock.call(step=2, scalars={
111 | "a": 5,
112 | "b": 0.007
113 | })
114 | ])
115 | self.sync_writer.flush.assert_called_once()
116 |
117 | def test_ensure_flushes_with_multiple_writers(self):
118 | sync_writer1 = mock.create_autospec(interface.MetricWriter)
119 | writer1 = async_writer.AsyncWriter(sync_writer1)
120 | sync_writer2 = mock.create_autospec(interface.MetricWriter)
121 | writer2 = async_writer.AsyncWriter(sync_writer2)
122 |
123 | with async_writer.ensure_flushes(writer1, writer2):
124 | writer1.write_scalars(0, {"a": 3, "b": 0.15})
125 | writer2.write_scalars(2, {"a": 5, "b": 0.007})
126 |
127 | sync_writer1.write_scalars.assert_has_calls(
128 | [mock.call(step=0, scalars={
129 | "a": 3,
130 | "b": 0.15
131 | })])
132 |
133 | sync_writer2.write_scalars.assert_has_calls(
134 | [mock.call(step=2, scalars={
135 | "a": 5,
136 | "b": 0.007
137 | })])
138 |
139 | sync_writer1.flush.assert_called_once()
140 | sync_writer2.flush.assert_called_once()
141 |
142 | def test_flush_before_close(self):
143 | self.writer.close()
144 | self.sync_writer.flush.assert_called()
145 | self.sync_writer.close.assert_called()
146 |
147 | def test_reraises_exception(self):
148 | self.sync_writer.write_scalars.side_effect = ValueError("foo")
149 | self.writer.write_scalars(0, {"a": 3, "b": 0.15})
150 | time.sleep(0.1)
151 | with self.assertRaisesRegex(asynclib.AsyncError, "Consider re-running"):
152 | self.writer.write_scalars(2, {"a": 5, "b": 0.007})
153 |
154 |
155 | if __name__ == "__main__":
156 | tf.test.main()
157 |
--------------------------------------------------------------------------------
/clu/asynclib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for async function calls."""
16 |
17 | import collections
18 | import concurrent.futures
19 | import functools
20 | import sys
21 | import threading
22 | from typing import Callable, List, Optional
23 |
24 | from absl import logging
25 |
26 |
27 | class AsyncError(Exception):
28 | """An exception that wraps another exception that ocurred asynchronously."""
29 |
30 |
31 | class Pool:
32 | """Pool for wrapping functions to be executed asynchronously.
33 |
34 | Synopsis:
35 |
36 |     from clu import asynclib
37 |
38 | pool = asynclib.Pool()
39 | @pool
40 | def fn():
41 | time.sleep(1)
42 |
43 | future = fn()
44 | print(future.result())
45 | fn() # This could re-raise an exception from the first execution.
46 |     print(pool.queue_length)  # Prints "1": one function is still in flight.
47 |     pool.join()  # This could re-raise an exception from the second execution.
48 | """
49 |
50 | def __init__(self, thread_name_prefix: str = "",
51 | max_workers: Optional[int] = None):
52 | """Creates a new pool that decorates functions for async execution.
53 |
54 | Args:
55 | thread_name_prefix: See documentation of `ThreadPoolExecutor`.
56 | max_workers: See documentation of `ThreadPoolExecutor`. The default `None`
57 |         optimizes for parallelism using the number of CPU cores. If you
58 |         specify `max_workers=1`, the async calls are executed in the same
59 |         order in which they were scheduled.
60 | """
61 | self._pool = concurrent.futures.ThreadPoolExecutor(
62 | max_workers=max_workers, thread_name_prefix=thread_name_prefix)
63 | self._max_workers = max_workers
64 | self._thread_name_prefix = thread_name_prefix
65 | self._errors = collections.deque()
66 | self._errors_mutex = threading.Lock()
67 | self._queue_length = 0
68 |
69 | def _reraise(self) -> None:
70 | if self._errors:
71 | with self._errors_mutex:
72 | exc_info = self._errors.popleft()
73 | exc = exc_info[1].with_traceback(exc_info[2])
74 | raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc
75 |
76 | def close(self) -> None:
77 | """Closes this pool & raise a pending exception (if needed)."""
78 | self._pool.shutdown(wait=True)
79 | self._reraise()
80 |
81 | def join(self) -> None:
82 | """Blocks until all functions are processed.
83 |
84 |     The pool can be used to schedule more functions after calling this function,
85 |     but further exceptions might still be pending (only the first is re-raised).
86 |
87 | Side-effect:
88 | If any of the functions raised an exception, then the first of these
89 | exceptions is reraised.
90 | """
91 | self._pool.shutdown(wait=True)
92 | self._pool = concurrent.futures.ThreadPoolExecutor(
93 | max_workers=self._max_workers,
94 | thread_name_prefix=self._thread_name_prefix)
95 | self._reraise()
96 |
97 | @property
98 | def queue_length(self) -> int:
99 | """Returns the number of functions that have not returned yet."""
100 | return self._queue_length
101 |
102 | @property
103 | def has_errors(self) -> bool:
104 | """Returns True if there are any pending errors."""
105 | return bool(self._errors)
106 |
107 | def clear_errors(self) -> List[Exception]:
108 | """Clears all pending errors and returns them as a (possibly empty) list."""
109 | with self._errors_mutex:
110 | errors, self._errors = self._errors, collections.deque()
111 | return list(errors)
112 |
113 | def __call__(self, fn: Callable): # pylint: disable=g-bare-generic
114 | """Returns an async version of fn.
115 |
116 | The function will be executed by this class's ThreadPoolExecutor. Any errors
117 | will be stored and re-raised next time any function is called that is
118 | executed through this pool.
119 |
120 | Note that even if there was a previous error, the function is still
121 | scheduled upon re-execution of the wrapper returned by this function.
122 |
123 | Args:
124 | fn: Function to be wrapped.
125 |
126 | Returns:
127 | An async version of `fn`. The return value of that async version will be
128 | a future (unless an exception was re-raised).
129 | """
130 |
131 | def inner(*args, **kwargs):
132 |
133 | def trap_errors(*args, **kwargs):
134 | try:
135 | return fn(*args, **kwargs)
136 | except Exception as e:
137 | with self._errors_mutex:
138 | self._errors.append(sys.exc_info())
139 | logging.exception("Error in producer thread for %s",
140 | self._thread_name_prefix)
141 | raise e
142 | finally:
143 | self._queue_length -= 1
144 |
145 | self._queue_length += 1
146 | if not self.has_errors:
147 | return self._pool.submit(trap_errors, *args, **kwargs)
148 | self._pool.submit(trap_errors, *args, **kwargs)
149 | self._reraise()
150 |
151 | if isinstance(fn.__name__, str):
152 | # Regular function.
153 | return functools.wraps(fn)(inner)
154 | # Mock or another weird function that fails with functools.wraps().
155 | return inner
156 |
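157 |
158 | # --- Usage sketch (editor's addition, not part of the library). A minimal,
159 | # hedged illustration of the error re-raising behaviour documented above: an
160 | # exception raised in a background thread surfaces later as an AsyncError.
161 | # The decorated function and the input values are arbitrary.
162 | if __name__ == "__main__":
163 |
164 |   demo_pool = Pool(thread_name_prefix="demo")
165 |
166 |   @demo_pool
167 |   def might_fail(x):
168 |     if x < 0:
169 |       raise ValueError(f"negative input: {x}")
170 |     return x * x
171 |
172 |   print(might_fail(3).result())  # Prints 9; the call ran in a worker thread.
173 |   might_fail(-1)  # The ValueError is trapped in the background thread.
174 |   try:
175 |     demo_pool.join()  # Blocks, then re-raises the trapped error.
176 |   except AsyncError as e:
177 |     print("caught:", e)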
--------------------------------------------------------------------------------
/clu/metric_writers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines a generic write interface.
16 |
17 | The write helper accepts a MetricWriter object and a Mapping[str,
18 | clu.metrics.Metric], and automatically writes to the appropriate typed write
19 | method of the writer depending on the type of the metric.
20 | """
21 |
22 | # pylint: disable=g-importing-member
23 |
24 | import collections
25 | import getpass
26 | import os
27 | import re
28 | from typing import Any, List, Mapping, Optional, Tuple, Union
29 |
30 | from absl import flags
31 | from absl import logging
32 | from clu import values
33 | from clu.metric_writers.async_writer import AsyncMultiWriter
34 | from clu.metric_writers.interface import MetricWriter
35 | from clu.metric_writers.logging_writer import LoggingWriter
36 | from clu.metric_writers.multi_writer import MultiWriter
37 | from clu.metric_writers.summary_writer import SummaryWriter
38 | from etils import epath
39 | import jax.numpy as jnp
40 | import numpy as np
41 |
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def _is_scalar(value: Any) -> bool:
47 | if isinstance(value, values.Scalar) or isinstance(
48 | value, (int, float, np.number)
49 | ):
50 | return True
51 | if isinstance(value, (np.ndarray, jnp.ndarray)):
52 | return value.ndim == 0 or value.size <= 1
53 | return False
54 |
55 |
56 | def write_values(
57 | writer: MetricWriter,
58 | step: int,
59 | metrics: Mapping[
60 | str, Union[values.Value, values.ArrayType, values.ScalarType]
61 | ],
62 | ):
63 | """Writes all provided metrics.
64 |
65 | Allows providing a mapping of name to Value object, where each Value
66 | specifies a type. The appropriate write method can then be called depending
67 | on the type.
68 |
69 | Args:
70 | writer: MetricWriter object
71 | step: Step at which the arrays were generated.
72 | metrics: Mapping from name to clu.values.Value object.
73 | """
74 | writes = collections.defaultdict(dict)
75 | histogram_num_buckets = collections.defaultdict(int)
76 | for k, v in metrics.items():
77 | if isinstance(v, values.Summary):
78 | writes[
79 | (writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
80 | ][k] = v.value
81 | elif _is_scalar(v):
82 | if isinstance(v, values.Scalar):
83 | writes[(writer.write_scalars, frozenset())][k] = v.value
84 | else:
85 | writes[(writer.write_scalars, frozenset())][k] = v
86 | elif isinstance(v, values.Image):
87 | writes[(writer.write_images, frozenset())][k] = v.value
88 | elif isinstance(v, values.Text):
89 | writes[(writer.write_texts, frozenset())][k] = v.value
90 | elif isinstance(v, values.HyperParam):
91 | writes[(writer.write_hparams, frozenset())][k] = v.value
92 | elif isinstance(v, values.Histogram):
93 | writes[(writer.write_histograms, frozenset())][k] = v.value
94 | histogram_num_buckets[k] = v.num_buckets
95 | elif isinstance(v, values.Audio):
96 | writes[(
97 | writer.write_audios,
98 | frozenset({"sample_rate": v.sample_rate}.items()),
99 | )][k] = v.value
100 | else:
101 | raise ValueError("Metric: ", k, " has unsupported value: ", v)
102 |
103 | for (fn, extra_args), vals in writes.items():
104 | if fn == writer.write_histograms:
105 | # for write_histograms, the num_buckets arg is a Dict indexed by name
106 | writer.write_histograms(step, vals, num_buckets=histogram_num_buckets)
107 | else:
108 | fn(step, vals, **dict(extra_args))
109 |
110 |
111 |
112 |
113 | def create_default_writer(
114 | logdir: Optional[epath.PathLike] = None,
115 | *,
116 | just_logging: bool = False,
117 | asynchronous: bool = True,
118 | collection: Optional[str] = None,
119 | ) -> MultiWriter:
120 | """Create the default writer for the platform.
121 |
122 | On most platforms this will create a MultiWriter that writes to multiple back
123 | ends (logging, TF summaries etc.).
124 |
125 | Args:
126 |     logdir: Logging dir to use for TF summary files. If empty/None, the
127 |       returned writer will not write TF summary files.
128 | just_logging: If True only use a LoggingWriter. This is useful in multi-host
129 | setups when only the first host should write metrics and all other hosts
130 | should only write to their own logs.
132 | asynchronous: If True return an AsyncMultiWriter to not block when writing
133 | metrics.
134 |     collection: A string which, if provided, indicates that the given metrics
135 |       should all be written to the same collection, or
136 |       grouping.
137 |
138 | Returns:
139 | A `MetricWriter` according to the platform and arguments.
140 | """
141 | if just_logging:
142 | if asynchronous:
143 | return AsyncMultiWriter([LoggingWriter(collection=collection)])
144 | else:
145 | return MultiWriter([LoggingWriter(collection=collection)])
146 | writers = [LoggingWriter(collection=collection)]
147 | if logdir is not None:
148 | logdir = epath.Path(logdir)
149 | if collection is not None:
150 | logdir /= collection
151 | writers.append(SummaryWriter(os.fspath(logdir)))
152 | if asynchronous:
153 | return AsyncMultiWriter(writers)
154 | return MultiWriter(writers)
155 |
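156 |
157 | # --- Usage sketch (editor's addition, not part of the library). A minimal,
158 | # hedged example of create_default_writer() and write_values() as defined
159 | # above; the metric names and values are arbitrary illustrations.
160 | if __name__ == "__main__":
161 |   demo_writer = create_default_writer(
162 |       logdir=None, just_logging=True, asynchronous=False)
163 |   write_values(demo_writer, step=0, metrics={
164 |       "loss": values.Scalar(0.25),
165 |       "note": values.Text("first step"),
166 |   })
167 |   demo_writer.flush()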
--------------------------------------------------------------------------------
/clu/metric_writers/logging_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter that writes all values to INFO log."""
16 |
17 | from collections.abc import Mapping
18 | from typing import Any, Optional
19 |
20 | from absl import logging
21 | from clu.metric_writers import interface
22 | import numpy as np
23 |
24 | Array = interface.Array
25 | Scalar = interface.Scalar
26 |
27 |
28 | class LoggingWriter(interface.MetricWriter):
29 | """MetricWriter that writes all values to INFO log."""
30 |
31 | def __init__(self, collection: Optional[str] = None):
32 | if collection:
33 | self._collection_str = f" collection={collection}"
34 | else:
35 | self._collection_str = ""
36 |
37 | def write_summaries(
38 | self, step: int,
39 | values: Mapping[str, Array],
40 | metadata: Optional[Mapping[str, Any]] = None):
41 | logging.info("[%d]%s Got raw tensors: %s.", step, self._collection_str,
42 | {k: v.shape for k, v in values.items()})
43 |
44 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
45 | values = [
46 | f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}"
47 | for k, v in sorted(scalars.items())
48 | ]
49 | logging.info("[%d]%s %s", step, self._collection_str, ", ".join(values))
50 |
51 | def write_images(self, step: int, images: Mapping[str, Array]):
52 | logging.info("[%d]%s Got images: %s.", step, self._collection_str,
53 | {k: v.shape for k, v in images.items()})
54 |
55 | def write_videos(self, step: int, videos: Mapping[str, Array]):
56 | logging.info("[%d]%s Got videos: %s.", step, self._collection_str,
57 | {k: v.shape for k, v in videos.items()})
58 |
59 | def write_audios(
60 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
61 | logging.info("[%d]%s Got audios: %s.", step, self._collection_str,
62 | {k: v.shape for k, v in audios.items()})
63 |
64 | def write_texts(self, step: int, texts: Mapping[str, str]):
65 | logging.info("[%d]%s Got texts: %s.", step, self._collection_str, texts)
66 |
67 | def write_histograms(self,
68 | step: int,
69 | arrays: Mapping[str, Array],
70 | num_buckets: Optional[Mapping[str, int]] = None):
71 | num_buckets = num_buckets or {}
72 | for key, value in arrays.items():
73 | histo, bins = _compute_histogram_as_tf(
74 | np.asarray(value), num_buckets=num_buckets.get(key))
75 | if histo is not None:
76 | logging.info("[%d]%s Histogram for %r = {%s}", step,
77 | self._collection_str, key,
78 | _get_histogram_as_string(histo, bins))
79 |
80 | def write_pointcloud(
81 | self,
82 | step: int,
83 | point_clouds: Mapping[str, Array],
84 | *,
85 | point_colors: Mapping[str, Any] | None = None,
86 | configs: Mapping[str, str | float | bool | None] | None = None,
87 | ):
88 | logging.info(
89 | "[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
90 | step,
91 | self._collection_str,
92 | {k: v.shape for k, v in point_clouds.items()},
93 | (
94 | {k: v.shape for k, v in point_colors.items()}
95 | if point_colors is not None
96 | else None
97 | ),
98 | configs,
99 | )
100 |
101 | def write_hparams(self, hparams: Mapping[str, Any]):
102 | logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)
103 |
104 | def flush(self):
105 | logging.flush()
106 |
107 | def close(self):
108 | self.flush()
109 |
110 |
111 | def _compute_histogram_as_tf(
112 | array: np.ndarray,
113 | num_buckets: Optional[int] = None
114 | ) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
115 | """Compute the histogram of the input array as TF would do.
116 |
117 | Args:
118 | array: Input data. The histogram is computed over the flattened array.
119 | num_buckets: The number of equal-width bins used to create the histogram.
120 |
121 | Returns:
122 | histo: A numpy array with the values of the histogram.
123 | bins: A numpy array with the bin edges (its length is length(histo)+1).
124 |
125 | If the histogram cannot be built because the array is empty, returns
126 | (None, None).
127 | """
128 | # See DEFAULT_BUCKET_COUNT in tensorboard/plugins/histogram/summary_v2.py
129 | num_buckets = num_buckets or 30
130 | if num_buckets < 2:
131 | logging.log_first_n(logging.WARNING,
132 | "num_buckets was automatically changed from %d to 2", 1,
133 | num_buckets)
134 | num_buckets = 2
135 |
136 | if array.size == 0:
137 | return None, None
138 |
139 | range_max = np.max(array)
140 | range_min = np.min(array)
141 | if range_max == range_min:
142 | histo = np.asarray([array.size], dtype=np.int64)
143 | bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64)
144 | else:
145 | histo, bins = np.histogram(
146 | array, bins=num_buckets, range=(range_min, range_max))
147 | bins = np.asarray(bins, dtype=np.float64)
148 |
149 | return histo, bins
150 |
151 |
152 | def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):
153 | # First items are right-open (i.e. [a, b)).
154 | items = [
155 | f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}"
156 | for i, count in enumerate(histo[:-1])
157 | ]
158 | # Last item is right-closed (i.e. [a, b]).
159 | items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}")
160 | return ", ".join(items)
161 |
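162 |
163 | # --- Usage sketch (editor's addition, not part of the library). A minimal,
164 | # hedged example of LoggingWriter: scalars and a histogram end up in the INFO
165 | # log. The metric names, values and bucket count are arbitrary.
166 | if __name__ == "__main__":
167 |   from absl import app
168 |
169 |   def _demo(argv):
170 |     del argv
171 |     demo_writer = LoggingWriter(collection="demo")
172 |     demo_writer.write_scalars(0, {"loss": 0.5, "accuracy": 0.9})
173 |     demo_writer.write_histograms(
174 |         0, {"weights": np.linspace(-1.0, 1.0, 100)},
175 |         num_buckets={"weights": 4})
176 |     demo_writer.flush()
177 |
178 |   app.run(_demo)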
--------------------------------------------------------------------------------
/clu/metric_writers/async_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """MetricWriter that writes metrics in a separate thread.
16 |
17 | - The order of the write calls is preserved.
18 | - Users need to call `flush()` or use the `ensure_flushes()` context to make
19 |   sure that all metrics have been written.
20 | - Errors while writing in the background thread will be re-raised in the main
21 | thread on the next write_*() call.
22 | """
23 |
24 | from collections.abc import Mapping, Sequence
25 | import contextlib
26 | from typing import Any, Optional
27 |
28 | from clu import asynclib
29 |
30 | from clu.metric_writers import interface
31 | from clu.metric_writers import multi_writer
32 | import wrapt
33 |
34 | Array = interface.Array
35 | Scalar = interface.Scalar
36 |
37 |
38 | @wrapt.decorator
39 | def _wrap_exceptions(wrapped, instance, args, kwargs):
40 | del instance
41 | try:
42 | return wrapped(*args, **kwargs)
43 | except asynclib.AsyncError as e:
44 | raise asynclib.AsyncError(
45 | "Consider re-running the code without AsyncWriter (e.g. creating a "
46 | "writer using "
47 | "`clu.metric_writers.create_default_writer(asynchronous=False)`)"
48 | ) from e
49 |
50 |
51 | class AsyncWriter(interface.MetricWriter):
52 | """MetricWriter that performs write operations in a separate thread.
53 |
54 |   All write operations will be executed in a background thread. If an exception
55 |   occurs in the background thread, it will be raised on the main thread on the
56 |   next call of one of the write_* methods.
57 |
58 |   Use num_workers > 1 at your own risk: if the underlying writer is not
59 |   thread-safe or does not expect out-of-order events, this can cause problems.
60 |   If num_workers is None then the ThreadPoolExecutor will use `os.cpu_count()`
61 |   worker threads.
62 | """
63 |
64 | def __init__(self,
65 | writer: interface.MetricWriter,
66 | *,
67 | num_workers: Optional[int] = 1):
68 | super().__init__()
69 | self._writer = writer
70 | # By default, we have a thread pool with a single worker to ensure that
71 | # calls to the function are run in order (but in a background thread).
72 | self._num_workers = num_workers
73 | self._pool = asynclib.Pool(
74 | thread_name_prefix="AsyncWriter", max_workers=num_workers)
75 |
76 |
77 | @_wrap_exceptions
78 | def write_summaries(
79 | self, step: int,
80 | values: Mapping[str, Array],
81 | metadata: Optional[Mapping[str, Any]] = None):
82 | self._pool(self._writer.write_summaries)(
83 | step=step, values=values, metadata=metadata)
84 |
85 | @_wrap_exceptions
86 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
87 | self._pool(self._writer.write_scalars)(step=step, scalars=scalars)
88 |
89 | @_wrap_exceptions
90 | def write_images(self, step: int, images: Mapping[str, Array]):
91 | self._pool(self._writer.write_images)(step=step, images=images)
92 |
93 | @_wrap_exceptions
94 | def write_videos(self, step: int, videos: Mapping[str, Array]):
95 | self._pool(self._writer.write_videos)(step=step, videos=videos)
96 |
97 | @_wrap_exceptions
98 | def write_audios(
99 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
100 | self._pool(self._writer.write_audios)(
101 | step=step, audios=audios, sample_rate=sample_rate)
102 |
103 | @_wrap_exceptions
104 | def write_texts(self, step: int, texts: Mapping[str, str]):
105 | self._pool(self._writer.write_texts)(step=step, texts=texts)
106 |
107 | @_wrap_exceptions
108 | def write_histograms(self,
109 | step: int,
110 | arrays: Mapping[str, Array],
111 | num_buckets: Optional[Mapping[str, int]] = None):
112 | self._pool(self._writer.write_histograms)(
113 | step=step, arrays=arrays, num_buckets=num_buckets)
114 |
115 | @_wrap_exceptions
116 | def write_pointcloud(
117 | self,
118 | step: int,
119 | point_clouds: Mapping[str, Array],
120 | *,
121 | point_colors: Mapping[str, Array] | None = None,
122 | configs: Mapping[str, str | float | bool | None] | None = None,
123 | ):
124 | self._pool(self._writer.write_pointcloud)(
125 | step=step,
126 | point_clouds=point_clouds,
127 | point_colors=point_colors,
128 | configs=configs,
129 | )
130 |
131 | @_wrap_exceptions
132 | def write_hparams(self, hparams: Mapping[str, Any]):
133 | self._pool(self._writer.write_hparams)(hparams=hparams)
134 |
135 | def flush(self):
136 | try:
137 | self._pool.join()
138 | finally:
139 | self._writer.flush()
140 |
141 | def close(self):
142 | try:
143 | self.flush()
144 | finally:
145 | self._writer.close()
146 |
147 |
148 | class AsyncMultiWriter(multi_writer.MultiWriter):
149 | """AsyncMultiWriter writes to multiple writes in a separate thread."""
150 |
151 | def __init__(self,
152 | writers: Sequence[interface.MetricWriter],
153 | *,
154 | num_workers: Optional[int] = 1):
155 | super().__init__([AsyncWriter(w, num_workers=num_workers) for w in writers])
156 |
157 |
158 | @contextlib.contextmanager
159 | def ensure_flushes(*writers: interface.MetricWriter):
160 | """Context manager which ensures that one or more writers are flushed."""
161 | try:
162 | # The caller should not need to use the yielded value, but we yield
163 | # the first writer to stay backwards compatible for a single writer.
164 | yield writers[0]
165 | finally:
166 | for writer in writers:
167 | writer.flush()
168 |
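169 |
170 | # --- Usage sketch (editor's addition, not part of the library). A minimal,
171 | # hedged example of the flow described in the module docstring: writes are
172 | # queued on a background thread and ensure_flushes() guarantees they are
173 | # written. LoggingWriter is used purely for illustration.
174 | if __name__ == "__main__":
175 |   from clu.metric_writers import logging_writer
176 |
177 |   demo_writer = AsyncWriter(logging_writer.LoggingWriter(collection="demo"))
178 |   with ensure_flushes(demo_writer):
179 |     for demo_step in range(3):
180 |       demo_writer.write_scalars(demo_step, {"loss": 1.0 / (demo_step + 1)})
181 |   demo_writer.close()  # Flushes again and shuts down the worker thread.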
--------------------------------------------------------------------------------
/clu/parameter_overview_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for parameter overviews."""
16 |
17 | from absl.testing import absltest
18 | from clu import parameter_overview
19 | from flax import linen as nn
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 |
25 | EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+
26 | | Name | Shape | Dtype | Size | Mean | Std |
27 | +------+-------+-------+------+------+-----+
28 | +------+-------+-------+------+------+-----+
29 | Total: 0 -- 0 bytes"""
30 |
31 | FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+---------+------+
32 | | Name | Shape | Dtype | Size |
33 | +-------------+--------------+---------+------+
34 | | conv/bias | (2,) | float32 | 2 |
35 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 |
36 | +-------------+--------------+---------+------+
37 | Total: 56 -- 224 bytes"""
38 |
39 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+
40 | | Name | Shape | Dtype | Size | Sharding |
41 | +-------------+--------------+---------+------+----------+
42 | | conv/bias | (2,) | float32 | 2 | () |
43 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | () |
44 | +-------------+--------------+---------+------+----------+
45 | Total: 56 -- 224 bytes"""
46 |
47 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+
48 | | Name | Shape | Dtype | Size | Mean | Std |
49 | +-------------+--------------+---------+------+------+-----+
50 | | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 |
51 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 |
52 | +-------------+--------------+---------+------+------+-----+
53 | Total: 56 -- 224 bytes"""
54 |
55 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+
56 | | Name | Shape | Dtype | Size | Mean | Std | Sharding |
57 | +-------------+--------------+---------+------+------+-----+----------+
58 | | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () |
59 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | () |
60 | +-------------+--------------+---------+------+------+-----+----------+
61 | Total: 56 -- 224 bytes"""
62 |
63 | FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+
64 | | Name | Shape | Dtype | Size | Mean | Std |
65 | +--------------------+--------------+---------+------+------+-----+
66 | | params/conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 |
67 | | params/conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 |
68 | +--------------------+--------------+---------+------+------+-----+
69 | Total: 56 -- 224 bytes"""
70 |
71 |
72 | class CNN(nn.Module):
73 |
74 | @nn.compact
75 | def __call__(self, x):
76 | return nn.Conv(features=2, kernel_size=(3, 3), name="conv")(x)
77 |
78 |
79 | class JaxParameterOverviewTest(absltest.TestCase):
80 |
81 | def test_count_parameters_empty(self):
82 | self.assertEqual(0, parameter_overview.count_parameters({}))
83 |
84 | def test_count_parameters(self):
85 | rng = jax.random.PRNGKey(42)
86 | # Weights of a 2D convolution with 2 filters.
87 | variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
88 | # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
89 | self.assertEqual(56,
90 | parameter_overview.count_parameters(variables["params"]))
91 |
92 | def test_get_parameter_overview_empty(self):
93 | self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
94 | parameter_overview.get_parameter_overview({}))
95 | self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
96 | parameter_overview.get_parameter_overview({"a": {}}))
97 |
98 | def test_get_parameter_overview(self):
99 | rng = jax.random.PRNGKey(42)
100 | # Weights of a 2D convolution with 2 filters.
101 | variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
102 | variables = jax.tree_util.tree_map(jnp.ones_like, variables)
103 | self.assertEqual(
104 | FLAX_CONV2D_PARAMETER_OVERVIEW,
105 | parameter_overview.get_parameter_overview(
106 | variables["params"], include_stats=False))
107 | self.assertEqual(
108 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
109 | parameter_overview.get_parameter_overview(variables["params"]))
110 | self.assertEqual(
111 | FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS,
112 | parameter_overview.get_parameter_overview(variables))
113 | # Add sharding with PartitionSpecs.
114 | mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d")
115 | sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
116 | variables = jax.jit(lambda x: x, out_shardings=sharding)(variables)
117 | self.assertEqual(
118 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING,
119 | parameter_overview.get_parameter_overview(
120 | variables["params"], include_stats="sharding"))
121 | self.assertEqual(
122 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING,
123 | parameter_overview.get_parameter_overview(
124 | variables["params"], include_stats="global"))
125 |
126 | def test_get_parameter_overview_shape_dtype_struct(self):
127 | variables_shape_dtype_struct = jax.eval_shape(
128 | lambda: CNN().init(jax.random.PRNGKey(42), jnp.zeros((2, 5, 5, 3))))
129 | self.assertEqual(
130 | FLAX_CONV2D_PARAMETER_OVERVIEW,
131 | parameter_overview.get_parameter_overview(
132 | variables_shape_dtype_struct["params"], include_stats=False))
133 |
134 | def test_printing_bool(self):
135 | self.assertEqual(
136 | parameter_overview._default_table_value_formatter(True), "True")
137 | self.assertEqual(
138 | parameter_overview._default_table_value_formatter(False), "False")
139 |
140 |
141 | if __name__ == "__main__":
142 | absltest.main()
143 |
--------------------------------------------------------------------------------
/clu/metric_writers/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for interface."""
16 | # pylint: disable=g-importing-member
17 |
18 | import itertools
19 | from typing import Any
20 | from unittest import mock
21 |
22 | from absl.testing import absltest
23 | from absl.testing import parameterized
24 | from clu import values
25 | from clu.metric_writers import utils
26 | from clu.metric_writers.async_writer import AsyncMultiWriter
27 | from clu.metric_writers.async_writer import AsyncWriter
28 | from clu.metric_writers.interface import MetricWriter
29 | from clu.metric_writers.logging_writer import LoggingWriter
30 | from clu.metric_writers.multi_writer import MultiWriter
31 | from clu.metric_writers.summary_writer import SummaryWriter
32 | import clu.metrics
33 | import flax.struct
34 | import jax.numpy as jnp
35 | import tensorflow as tf
36 |
37 |
38 | @flax.struct.dataclass
39 | class HistogramMetric(clu.metrics.Metric):
40 | value: jnp.ndarray
41 | num_buckets: int
42 |
43 | def compute_value(self):
44 | return values.Histogram(self.value, self.num_buckets)
45 |
46 |
47 | @flax.struct.dataclass
48 | class ImageMetric(clu.metrics.Metric):
49 | value: jnp.ndarray
50 |
51 | def compute_value(self):
52 | return values.Image(self.value)
53 |
54 |
55 | @flax.struct.dataclass
56 | class AudioMetric(clu.metrics.Metric):
57 | value: jnp.ndarray
58 | sample_rate: int
59 |
60 | def compute_value(self):
61 | return values.Audio(self.value, self.sample_rate)
62 |
63 |
64 | @flax.struct.dataclass
65 | class TextMetric(clu.metrics.Metric):
66 | value: str
67 |
68 | def compute_value(self):
69 | return values.Text(self.value)
70 |
71 |
72 | @flax.struct.dataclass
73 | class HyperParamMetric(clu.metrics.Metric):
74 | value: float
75 |
76 | def compute_value(self):
77 | return values.HyperParam(self.value)
78 |
79 |
80 | @flax.struct.dataclass
81 | class SummaryMetric(clu.metrics.Metric):
82 | value: jnp.ndarray
83 | metadata: Any
84 |
85 | def compute_value(self):
86 | return values.Summary(self.value, self.metadata)
87 |
88 |
89 | def _to_summary(metrics):
90 | return {k: v.value for k, v in metrics.items()}
91 |
92 |
93 | def _to_list_of_dicts(d):
94 | return [{k: v} for k, v in d.items()]
95 |
96 |
97 | class ONEOF(object):
98 | """ONEOF(options_list) check value in options_list."""
99 |
100 | def __init__(self, container):
101 | if not hasattr(container, "__contains__"):
102 | raise TypeError(f"{container!r} is not a container")
103 | if not container:
104 | raise ValueError(f"{container!r} is empty")
105 | self._c = container
106 |
107 | def __eq__(self, o):
108 | return o in self._c
109 |
110 | def __ne__(self, o):
111 | return o not in self._c
112 |
113 | def __repr__(self):
114 | return "".format(",".join(repr(i) for i in self._c))
115 |
116 |
117 | class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):
118 |
119 | def test_write(self):
120 | writer = mock.Mock(spec_set=MetricWriter)
121 | step = 3
122 | num_buckets = 4
123 | sample_rate = 10
124 | scalar_metrics = {
125 | "loss": clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])),
126 | "accuracy": clu.metrics.LastValue.from_model_output(jnp.asarray([5])),
127 | }
128 | image_metrics = {
129 | "image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
130 | }
131 | histogram_metrics = {
132 | "hist": HistogramMetric(
133 | value=jnp.asarray([7, 8]), num_buckets=num_buckets
134 | ),
135 | "hist2": HistogramMetric(
136 | value=jnp.asarray([9, 10]), num_buckets=num_buckets
137 | ),
138 | }
139 | audio_metrics = {
140 | "audio": AudioMetric(
141 | value=jnp.asarray([1, 5]), sample_rate=sample_rate
142 | ),
143 | "audio2": AudioMetric(
144 | value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
145 | ),
146 | }
147 | text_metrics = {
148 | "text": TextMetric(value="hello"),
149 | }
150 | hparam_metrics = {
151 | "lr": HyperParamMetric(value=0.01),
152 | }
153 | summary_metrics = {
154 | "summary": SummaryMetric(
155 | value=jnp.asarray([2, 3, 10]), metadata="some info"
156 | ),
157 | "summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
158 | }
159 | metrics = {
160 | **scalar_metrics,
161 | **image_metrics,
162 | **histogram_metrics,
163 | **audio_metrics,
164 | **text_metrics,
165 | **hparam_metrics,
166 | **summary_metrics,
167 | }
168 | metrics = {k: m.compute_value() for k, m in metrics.items()}
169 | utils.write_values(writer, step, metrics)
170 |
171 | writer.write_scalars.assert_called_once_with(
172 | step, {k: m.compute() for k, m in scalar_metrics.items()}
173 | )
174 | writer.write_images.assert_called_once_with(
175 | step, _to_summary(image_metrics)
176 | )
177 | writer.write_histograms.assert_called_once_with(
178 | step,
179 | _to_summary(histogram_metrics),
180 | num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
181 | )
182 | writer.write_audios.assert_called_with(
183 | step,
184 | ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
185 | sample_rate=ONEOF([sample_rate, sample_rate + 2]),
186 | )
187 | writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
188 | writer.write_hparams.assert_called_once_with(
189 | step, _to_summary(hparam_metrics)
190 | )
191 | writer.write_summaries.assert_called_with(
192 | step,
193 | ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
194 | metadata=ONEOF(["some info", 5]),
195 | )
196 |
197 |
198 | def test_create_default_writer_summary_writer_is_added(self):
199 | writer = utils.create_default_writer(
200 | logdir=self.get_temp_dir(), asynchronous=False
201 | )
202 | self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
203 | writer = utils.create_default_writer(logdir=None, asynchronous=False)
204 | self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))
205 |
206 |
207 | if __name__ == "__main__":
208 | absltest.main()
209 |
--------------------------------------------------------------------------------
/clu/preprocess_spec_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 |
17 | from absl import logging
18 | from absl.testing import parameterized
19 | from clu import preprocess_spec
20 | import tensorflow as tf
21 |
22 | Features = preprocess_spec.Features
23 | SEED_KEY = preprocess_spec.SEED_KEY
24 |
25 |
26 | @dataclasses.dataclass(frozen=True)
27 | class ToFloat:
28 |
29 | def __call__(self, features: Features) -> Features:
30 | return {k: tf.cast(v, tf.float32) / 255.0 for k, v in features.items()}
31 |
32 |
33 | @dataclasses.dataclass(frozen=True)
34 | class Rescale:
35 |
36 | scale: int = 1
37 |
38 | def __call__(self, features: Features) -> Features:
39 | features["image"] *= self.scale
40 | features["segmentation_mask"] *= self.scale
41 | return features
42 |
43 |
44 | @dataclasses.dataclass(frozen=True)
45 | class AddRandomInteger(preprocess_spec.RandomMapTransform):
46 |
47 | def _transform(self, features, seed):
48 | features["x"] = tf.random.stateless_uniform([], seed)
49 | return features
50 |
51 |
52 | all_ops = lambda: preprocess_spec.get_all_ops(__name__)
53 |
54 |
55 | class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase):
56 | """Tests for parsing preprocessing op spec."""
57 |
58 | def test_no_arguments(self):
59 | op = preprocess_spec._parse_single_preprocess_op("rescale", dict(all_ops()))
60 | logging.info("op: %r", op)
61 | self.assertEqual(str(op), "Rescale(scale=1)")
62 |
63 | def test_positional_argument(self):
64 | op = preprocess_spec._parse_single_preprocess_op("rescale(2)",
65 | dict(all_ops()))
66 | logging.info("op: %r", op)
67 | self.assertEqual(str(op), "Rescale(scale=2)")
68 |
69 | def test_keyword_argument(self):
70 | op = preprocess_spec._parse_single_preprocess_op("rescale(scale=3)",
71 | dict(all_ops()))
72 | logging.info("op: %r", op)
73 | self.assertEqual(str(op), "Rescale(scale=3)")
74 |
75 | def test_invalid_op_name(self):
76 | with self.assertRaisesRegex(
77 | ValueError, r"'does_not_exist' is not available \(available ops: "
78 | r"\['add_random_integer', 'rescale', 'to_float'\]\)."):
79 | preprocess_spec._parse_single_preprocess_op("does_not_exist",
80 | dict(all_ops()))
81 |
82 | def test_invalid_spec(self):
83 | with self.assertRaisesRegex(
84 | ValueError, r"'rescale\)' is not a valid preprocess op spec."):
85 | preprocess_spec._parse_single_preprocess_op("rescale)", dict(all_ops()))
86 |
87 | def test_pos_and_kw_arg(self):
88 | with self.assertRaisesRegex(
89 | ValueError,
90 | r"Rescale'> given both as positional argument \(value: 2\) and keyword "
91 | r"argument \(value: 3\)."):
92 | preprocess_spec._parse_single_preprocess_op("rescale(2, scale=3)",
93 | dict(all_ops()))
94 |
95 | def test_parsing_empty_string(self):
96 | preprocess_fn = preprocess_spec.parse("", all_ops())
97 | self.assertEqual(
98 | str(preprocess_fn), "PreprocessFn(ops=[], only_jax_types=True)")
99 |
100 | def test_multi_op_spec(self):
101 | preprocess_fn = preprocess_spec.parse("to_float|rescale(3)", all_ops())
102 | logging.info("preprocess_fn: %r", preprocess_fn)
103 | self.assertEqual(str(preprocess_fn.ops), "[ToFloat(), Rescale(scale=3)]")
104 |
105 | def test_two_tensors(self):
106 | preprocess_fn = preprocess_spec.parse("rescale(scale=7)", all_ops())
107 | x = {"image": tf.constant(3), "segmentation_mask": tf.constant(2)}
108 | y = preprocess_fn(x)
109 | self.assertEqual(y, {
110 | "image": tf.constant(21),
111 | "segmentation_mask": tf.constant(14),
112 | })
113 |
114 | def test_only_jax_types(self):
115 | preprocess_fn = preprocess_spec.parse("", all_ops())
116 | x = {
117 | "image": tf.constant(2),
118 | # Strings are not supported.
119 | "label": tf.constant("bla"),
120 | # Sparse tensors are not supported.
121 | "foo": tf.sparse.eye(4),
122 | # Ragged tensors are not supported.
123 | "bar": tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]]),
124 | }
125 | y = preprocess_fn(x)
126 | self.assertEqual(y, {"image": tf.constant(2)})
127 |
128 | def test_only_jax_types_nested_inputs(self):
129 | preprocess_fn = preprocess_spec.parse("", all_ops())
130 | x = {
131 | "nested": {
132 | "not_allowed": tf.constant("bla"),
133 | "allowed": tf.constant(2),
134 | }
135 | }
136 | y = preprocess_fn(x)
137 | self.assertEqual(y, {"nested": {"allowed": tf.constant(2)}})
138 |
139 | def test_not_only_jax_types(self):
140 | preprocess_fn = preprocess_spec.parse("", all_ops(), only_jax_types=False)
141 | x = {"image": tf.constant(2), "label": tf.constant("bla")}
142 | y = preprocess_fn(x)
143 | self.assertEqual(y, x)
144 |
145 | def test_add_preprocess_fn(self):
146 | op1 = ToFloat()
147 | op2 = ToFloat()
148 | op3 = ToFloat()
149 | fn1 = preprocess_spec.PreprocessFn(ops=(op1, op2), only_jax_types=False)
150 | fn2 = preprocess_spec.PreprocessFn(ops=(op3,), only_jax_types=True)
151 | fn12 = fn1 + fn2
152 | # Note: `+` is not supported on Sequence[PreprocessOp]; need to use `list`.
153 | self.assertSequenceEqual(fn12.ops, list(fn1.ops) + list(fn2.ops))
154 | self.assertTrue(fn12.only_jax_types)
155 |
156 | def test_slice_preprocess_fn(self):
157 | op1 = ToFloat()
158 | op2 = Rescale()
159 | op3 = ToFloat()
160 | fn = preprocess_spec.PreprocessFn(ops=(op1, op2, op3), only_jax_types=True)
161 | self.assertEqual(fn[:-1].ops, (op1, op2))
162 | self.assertTrue(fn[:-1].only_jax_types)
163 | self.assertEqual(fn[1].ops, [op2])
164 | self.assertTrue(fn[1].only_jax_types)
165 |
166 | def test_random_map_transform(self):
167 | ds = tf.data.Dataset.from_tensor_slices(
168 | {SEED_KEY: [[1, 2], [3, 4], [1, 2]]})
169 | ds = ds.map(AddRandomInteger())
170 | actual = list(ds)
171 | print("actual:", actual)
172 | expect = [
173 | # Random number was generated and random seed changed.
174 | {
175 | "x": 0.8838011,
176 | SEED_KEY: [1105988140, 1738052849]
177 | },
178 | {
179 | "x": 0.33396423,
180 | SEED_KEY: [-1860230133, -671226999]
181 | },
182 | # Same random seed as first element creates same outcome.
183 | {
184 | "x": 0.8838011,
185 | SEED_KEY: [1105988140, 1738052849]
186 | },
187 | ]
188 | self.assertAllClose(actual, expect)
189 |
190 |
191 | if __name__ == "__main__":
192 | tf.test.main()
193 |
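194 |
195 | # Illustrative sketch of the spec syntax exercised above: op names are the
196 | # snake_case class names, "|" chains ops left to right, and arguments may be
197 | # given positionally or by keyword. Applied to a tf.data pipeline this might
198 | # look like (input values are placeholders):
199 | #
200 | #   fn = preprocess_spec.parse("to_float|rescale(scale=2)", all_ops())
201 | #   ds = tf.data.Dataset.from_tensor_slices(
202 | #       {"image": [[0, 255]], "segmentation_mask": [[1, 2]]})
203 | #   ds = ds.map(fn)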
--------------------------------------------------------------------------------
/clu/metric_writers/interface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library for unify reporting model metrics across various logging formats.
16 |
17 | This library provides a MetricWriter for each logging format (SummyWriter,
18 | LoggingWriter, etc.) and composing MetricWriter to add support for asynchronous
19 | logging or writing to multiple formats.
20 | """
21 |
22 | import abc
23 | from collections.abc import Mapping
24 | from typing import Any, Optional, Union
25 |
26 | import jax.numpy as jnp
27 | import numpy as np
28 |
29 | Array = Union[np.ndarray, jnp.ndarray]
30 | Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
31 |
32 |
33 | class MetricWriter(abc.ABC):
34 | """MetricWriter inferface."""
35 |
36 | @abc.abstractmethod
37 | def write_summaries(
38 | self, step: int,
39 | values: Mapping[str, Array],
40 | metadata: Optional[Mapping[str, Any]] = None):
41 | """Saves an arbitrary tensor summary.
42 |
43 | Useful when working with custom plugins or constructing a summary directly.
44 |
45 | Args:
46 | step: Step at which the scalar values occurred.
47 | values: Mapping from tensor keys to tensors.
48 | metadata: Optional SummaryMetadata, as a proto or serialized bytes.
49 | Note that markdown formatting is rendered by tensorboard.
50 | """
51 |
52 | @abc.abstractmethod
53 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
54 | """Write scalar values for the step.
55 |
56 | Consecutive calls to this method can provide different sets of scalars.
57 | Repeated writes for the same metric at the same step are not allowed.
58 |
59 | Args:
60 | step: Step at which the scalar values occurred.
61 | scalars: Mapping from metric name to value.
62 | """
63 |
64 | @abc.abstractmethod
65 | def write_images(self, step: int, images: Mapping[str, Array]):
66 | """Write images for the step.
67 |
68 | Consecutive calls to this method can provide different sets of images.
69 | Repeated writes for the same image key at the same step are not allowed.
70 |
71 |     Warning: Not all MetricWriter implementations support writing images!
72 |
73 | Args:
74 | step: Step at which the images occurred.
75 | images: Mapping from image key to images. Images should have the shape [N,
76 | H, W, C] or [H, W, C], where H is the height, W is the width and C the
77 | number of channels (1 or 3). N is the number of images that will be
78 | written. Image dimensions can differ between different image keys but
79 | not between different steps for the same image key.
80 | """
81 |
82 | @abc.abstractmethod
83 | def write_videos(self, step: int, videos: Mapping[str, Array]):
84 | """Write videos for the step.
85 |
86 | Warning: Logging only.
87 |     Not all MetricWriter implementations support writing videos!
88 |
89 | Consecutive calls to this method can provide different sets of videos.
90 | Repeated writes for the same video key at the same step are not allowed.
91 |
92 |
93 | Args:
94 | step: Step at which the videos occurred.
95 | videos: Mapping from video key to videos. videos should have the shape
96 | [N, T, H, W, C] or [T, H, W, C], where T is time, H is the height,
97 | W is the width and C the number of channels (1 or 3). N is the number
98 | of videos that will be written. Video dimensions can differ between
99 | different video keys but not between different steps for the same
100 | video key.
101 | """
102 |
103 | @abc.abstractmethod
104 | def write_audios(
105 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
106 | """Write audios for the step.
107 |
108 | Consecutive calls to this method can provide different sets of audios.
109 | Repeated writes for the same audio key at the same step are not allowed.
110 |
111 |     Warning: Not all MetricWriter implementations support writing audios!
112 |
113 | Args:
114 | step: Step at which the audios occurred.
115 | audios: Mapping from audio key to audios. Audios should have the shape
116 | [N, T, C], where T is the time length and C the number of channels
117 | (1 - mono, 2 - stereo, >= 3 - surround; not all writers support any
118 | number of channels). N is the number of audios that will be written.
119 | Audio dimensions can differ between different audio keys but not between
120 | different steps for the same audio key. Values should be floating-point
121 | values in [-1, +1].
122 | sample_rate: Sample rate for the audios.
123 | """
124 |
125 | @abc.abstractmethod
126 | def write_texts(self, step: int, texts: Mapping[str, str]):
127 | """Writes text snippets for the step.
128 |
129 |     Warning: Not all MetricWriter implementations support writing text!
130 |
131 | Args:
132 | step: Step at which the text snippets occurred.
133 | texts: Mapping from name to text snippet.
134 | """
135 |
136 | @abc.abstractmethod
137 | def write_histograms(self,
138 | step: int,
139 | arrays: Mapping[str, Array],
140 | num_buckets: Optional[Mapping[str, int]] = None):
141 | """Writes histograms for the step.
142 |
143 |     Consecutive calls to this method can provide different sets of arrays.
144 |     Repeated writes for the same metric at the same step are not allowed.
145 |
146 |     Warning: Not all MetricWriter implementations support writing histograms!
147 |
148 | Args:
149 | step: Step at which the arrays were generated.
150 | arrays: Mapping from name to arrays to summarize.
151 | num_buckets: Number of buckets used to create the histogram of the arrays.
152 | The default number of buckets depends on the particular implementation
153 | of the MetricWriter.
154 | """
155 |
156 | def write_pointcloud(
157 | self,
158 | step: int,
159 | point_clouds: Mapping[str, Array],
160 | *,
161 | point_colors: Mapping[str, Array] | None = None,
162 | configs: Mapping[str, str | float | bool | None] | None = None,
163 | ):
164 | """Writes point cloud summaries.
165 |
166 | Args:
167 | step: Step at which the point cloud was generated.
168 | point_clouds: Mapping from point clouds key to point cloud of shape [N, 3]
169 | array of point coordinates.
170 | point_colors: Mapping from point colors key to [N, 3] array of point
171 | colors.
172 | configs: A dictionary of configuration options for the point cloud.
173 | """
174 | raise NotImplementedError()
175 |
176 | @abc.abstractmethod
177 | def write_hparams(self, hparams: Mapping[str, Any]):
178 |     """Writes hyperparameters.
179 |
180 |     Do not call this method twice.
181 |
182 |     Args:
183 |       hparams: Flat mapping from hyperparameter name to value.
184 | """
185 |
186 | @abc.abstractmethod
187 | def flush(self):
188 | """Tells the MetricWriter to write out any cached values."""
189 |
190 | @abc.abstractmethod
191 | def close(self):
192 | """Flushes and closes the MetricWriter.
193 |
194 | Calling any method on MetricWriter after MetricWriter.close()
195 | is undefined behavior.
196 | """
197 |
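198 |
199 | # A minimal usage sketch (illustrative only, not part of the abstract
200 | # interface): a training loop obtains a concrete MetricWriter (e.g. a
201 | # SummaryWriter, a LoggingWriter, or a composed async/multi writer) and calls
202 | # the write_* methods with a monotonically increasing step.
203 | def _example_training_loop(writer: MetricWriter, num_steps: int) -> None:
204 |   """Sketch of typical MetricWriter usage; illustrative only."""
205 |   writer.write_hparams({"learning_rate": 0.01, "batch_size": 128})
206 |   for step in range(1, num_steps + 1):
207 |     loss = 1.0 / step  # Stand-in for the loss from a real training step.
208 |     writer.write_scalars(step, {"train_loss": loss})
209 |     if step % 100 == 0:
210 |       writer.write_images(step, {"inputs": np.zeros((1, 8, 8, 3), np.uint8)})
211 |   writer.flush()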
--------------------------------------------------------------------------------
/clu/metric_writers/tf/summary_writer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for SummaryWriter."""
16 |
17 | import collections
18 | import os
19 |
20 | from clu.metric_writers.tf import summary_writer
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from tensorboard.plugins.hparams import plugin_data_pb2
25 |
26 |
27 | def _load_summaries_data(logdir):
28 | """Loads raw summaries data from events in a logdir."""
29 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
30 | data = collections.defaultdict(dict)
31 | metadata = collections.defaultdict(dict)
32 | for path in paths:
33 | for event in tf.compat.v1.train.summary_iterator(path):
34 | for value in event.summary.value:
35 | data[event.step][value.tag] = tf.make_ndarray(value.tensor)
36 | if value.HasField("metadata"):
37 | metadata[event.step][value.tag] = value.metadata.SerializeToString()
38 | return data, metadata
39 |
40 |
41 | def _load_histograms_data(logdir):
42 | """Loads tensor summaries from events in a logdir."""
43 | # Note: new versions of histograms don't use the HistogramProto type, but
44 | # they are written as tensors representing the bounds and counts of buckets,
45 | # with plugin_name = "histogram".
46 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
47 | data = {}
48 | for path in paths:
49 | for event in tf.compat.v1.train.summary_iterator(path):
50 | for value in event.summary.value:
51 | current_steps, current_tensors = data.get(value.tag, ([], []))
52 | data[value.tag] = (current_steps + [event.step],
53 | current_tensors + [tf.make_ndarray(value.tensor)])
54 | return {
55 | tag: (np.stack(steps), np.stack(tensors))
56 | for tag, (steps, tensors) in data.items()
57 | }
58 |
59 |
60 | def _load_scalars_data(logdir: str):
61 | """Loads scalar summaries from events in a logdir."""
62 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
63 | data = collections.defaultdict(dict)
64 | for path in paths:
65 | for event in tf.compat.v1.train.summary_iterator(path):
66 | for value in event.summary.value:
67 | data[event.step][value.tag] = tf.make_ndarray(value.tensor).flat[0]
68 |
69 | return data
70 |
71 |
72 | def _load_pointcloud_data(logdir: str):
73 | """Loads pointcloud summaries from events in a logdir."""
74 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
75 | data = collections.defaultdict(dict)
76 | for path in paths:
77 | for event in tf.compat.v1.train.summary_iterator(path):
78 | for value in event.summary.value:
79 | if value.metadata.plugin_data.plugin_name == "mesh":
80 | if "config" not in value.tag:
81 | data[event.step][value.tag] = tf.make_ndarray(value.tensor)
82 | else:
83 | data[event.step][value.tag] = value.metadata.plugin_data.content
84 | return data
85 |
86 |
87 | def _load_hparams(logdir: str):
88 | """Loads hparams summaries from events in a logdir."""
89 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
90 | # data = collections.defaultdict(dict)
91 | hparams = []
92 | for path in paths:
93 | for event in tf.compat.v1.train.summary_iterator(path):
94 | for value in event.summary.value:
95 | if value.metadata.plugin_data.plugin_name == "hparams":
96 | hparams.append(plugin_data_pb2.HParamsPluginData.FromString(
97 | value.metadata.plugin_data.content))
98 | return hparams
99 |
100 |
101 | class SummaryWriterTest(tf.test.TestCase):
102 |
103 | def setUp(self):
104 | super().setUp()
105 | self.logdir = self.get_temp_dir()
106 | self.writer = summary_writer.SummaryWriter(self.logdir)
107 |
108 | def test_write_summaries(self):
109 | self.writer.write_summaries(
110 | 11,
111 | {"a": np.eye(3, dtype=np.uint8),
112 | "b": np.eye(2, dtype=np.float32)},
113 | {"a": np.ones((2, 3)).tobytes()})
114 | self.writer.flush()
115 | data, metadata = _load_summaries_data(self.logdir)
116 | self.assertAllClose(
117 | data[11],
118 | {"a": np.eye(3, dtype=np.uint8), "b": np.eye(2, dtype=np.float32)})
119 | self.assertIn("a", metadata[11])
120 |
121 | def test_write_scalar(self):
122 | self.writer.write_scalars(11, {"a": 0.6, "b": 15})
123 | self.writer.write_scalars(20, {"a": 0.8, "b": 12})
124 | self.writer.flush()
125 | data = _load_scalars_data(self.logdir)
126 | self.assertAllClose(data[11], {"a": 0.6, "b": 15})
127 | self.assertAllClose(data[20], {"a": 0.8, "b": 12})
128 |
129 | def test_write_histograms(self):
130 | self.writer.write_histograms(
131 | 0, {
132 | "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
133 | "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
134 | }, num_buckets={"a": 2, "b": 2})
135 | self.writer.write_histograms(
136 | 2, {
137 | "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
138 | "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
139 | }, num_buckets={"a": 2, "b": 2})
140 | self.writer.flush()
141 | data = _load_histograms_data(self.logdir)
142 | # In the histograms, each tuple represents
143 | # (bucket_min, bucket_max, bucket_count), where bucket_min is inclusive and
144 | # bucket_max is exclusive (except the last bucket_max which is inclusive).
145 | expected_histograms_a = [
146 | # Step 0.
147 | [(0.1, 0.4, 3), (0.4, 0.7, 2)],
148 | # Step 1.
149 | [(-0.1, 0.2, 2), (0.2, 0.5, 3)],
150 | ]
151 | self.assertAllClose(data["a"], ([0, 2], expected_histograms_a))
152 | expected_histograms_b = [
153 | # Step 0.
154 | [(-0.1, 0.15, 1), (0.15, 0.4, 4)],
155 | # Step 1.
156 | [(0.0, 0.35, 4), (0.35, 0.7, 1)],
157 | ]
158 | self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))
159 |
160 | def test_write_pointcloud(self):
161 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
162 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
163 | config = {
164 | "material": "PointCloudMaterial",
165 | "size": 0.09,
166 | }
167 | self.writer.write_pointcloud(
168 | step=0,
169 | point_clouds={"pcd": point_clouds},
170 | point_colors={"pcd": point_colors},
171 | configs={"config": config},
172 | )
173 | self.writer.flush()
174 | data = _load_pointcloud_data(self.logdir)
175 | self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
176 | self.assertAllClose(data[0]["pcd_COLOR"], point_colors)
177 |
178 | def test_hparams(self):
179 | self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
180 | hparams = _load_hparams(self.logdir)
181 | self.assertLen(hparams, 1)
182 | hparams_dict = hparams[0].session_start_info.hparams
183 | self.assertLen(hparams_dict, 2)
184 | self.assertEqual(512, hparams_dict["batch_size"].number_value)
185 | self.assertEqual(90, hparams_dict["num_epochs"].number_value)
186 |
187 | def test_hparams_nested(self):
188 | config = {
189 | "list": [1, 2],
190 | "tuple": (3, 4),
191 | "subconfig": {
192 | "value": "a",
193 | "list": [10, 20],
194 | },
195 | }
196 | self.writer.write_hparams(config)
197 | hparams = _load_hparams(self.logdir)
198 | self.assertLen(hparams, 1)
199 | hparams_dict = hparams[0].session_start_info.hparams
200 | self.assertLen(hparams_dict, 7)
201 | self.assertEqual(1, hparams_dict["list.0"].number_value)
202 | self.assertEqual(2, hparams_dict["list.1"].number_value)
203 | self.assertEqual(3, hparams_dict["tuple.0"].number_value)
204 | self.assertEqual(4, hparams_dict["tuple.1"].number_value)
205 | self.assertEqual("a", hparams_dict["subconfig.value"].string_value)
206 | self.assertEqual(10, hparams_dict["subconfig.list.0"].number_value)
207 | self.assertEqual(20, hparams_dict["subconfig.list.1"].number_value)
208 |
209 | if __name__ == "__main__":
210 | tf.test.main()
211 |
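212 |
213 | # Illustrative sketch of the round trip verified above: SummaryWriter writes
214 | # TensorBoard event files under its logdir, so e.g.
215 | #
216 | #   writer = summary_writer.SummaryWriter("/tmp/logdir")
217 | #   writer.write_scalars(11, {"a": 0.6})
218 | #   writer.flush()
219 | #
220 | # can be read back with tf.compat.v1.train.summary_iterator (see the _load_*
221 | # helpers above). Nested hparams are flattened with "." separators, e.g.
222 | # {"subconfig": {"value": "a"}} is stored under the key "subconfig.value".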
--------------------------------------------------------------------------------
/clu/data/dataset_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Interface for dataset iterators.
16 |
17 | This module provides the DatasetIterator interface. The intention is that
18 | several frameworks providing datasets can implement this interface without
19 | knowing anything about the framework used for the model and the training loop.
20 | Likewise, training loops can assume they get a DatasetIterator object and do
21 | not need to care about the specifics of the input pipelines.
22 |
23 | This module does not depend on TensorFlow. The interface is generic and users
24 | don't have to use `tf.data` to construct a DatasetIterator. However, if they
25 | use `tf.data` they can simply wrap their `tf.data.Dataset` object with
26 | `TfDatasetIterator` to satisfy the interface.
27 | """
28 | from __future__ import annotations
29 |
30 | import abc
31 | import collections.abc
32 | import concurrent.futures
33 | import dataclasses
34 | import os
35 | import threading
36 | import typing
37 | from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union
38 |
39 | from absl import logging
40 | from clu import asynclib
41 | from etils import epath
42 | import jax.numpy as jnp # Just for type checking.
43 | import numpy as np
44 | import numpy.typing as npt
45 |
46 | Array = Union[np.ndarray, jnp.ndarray]
47 | # Sizes of dimensions, None means the dimension size is unknown.
48 | Shape = Tuple[Optional[int], ...]
49 |
50 |
51 | @dataclasses.dataclass(frozen=True)
52 | class ArraySpec:
53 | """Describes an array via it's dtype and shape."""
54 | dtype: npt.DTypeLike
55 | shape: Shape
56 |
57 | def __repr__(self):
58 | return f"ArraySpec(dtype={np.dtype(self.dtype).name}, shape={self.shape})"
59 |
60 | def __str__(self):
61 | return f"{np.dtype(self.dtype).name}{list(self.shape)}"
62 |
63 |
64 | # Elements are PyTrees with NumPy/JAX arrays.
65 |
66 | # Anything can be a PyTree (it's either a container or leaf). We define
67 | # PyTree[T] as a PyTree where all leaves are of type T.
68 | # See https://jax.readthedocs.io/en/latest/pytrees.html.
69 | L = TypeVar("L") # pylint: disable=invalid-name
70 |
71 | PyTree = Union[L, Sequence["PyTree[L]"], Mapping[str, "PyTree[L]"]]
72 |
73 | Element = PyTree[Array]
74 | ElementSpec = PyTree[ArraySpec]
75 |
76 |
77 | class DatasetIterator(collections.abc.Iterator): # pytype: disable=ignored-abstractmethod
78 | """Generic interface for iterating over a dataset.
79 |
80 | This does not support __getitem__ since it cannot be implemented efficiently
81 | for many datasets. However datasets should allow starting the iterator from
82 | an arbitrary position.
83 |
84 | The element_spec property helps consumers to validate the input without
85 | reading data. This is similar to `tf.data.Dataset.element_spec`.
86 |
87 |   Subclasses may decide not to read/write checkpoints if their state is
88 | sufficiently tracked externally (e.g. input pipelines that can be correctly
89 | restarted from the step number).
90 | """
91 |
92 | def get_next(self) -> Element:
93 | """Returns the next element."""
94 | logging.error(
95 | "DatasetIterator.get_next() is deprecated. Please use next().")
96 | # Subclasses should implement __next__() and remove calls to get_next().
97 | return next(self)
98 |
99 | def reset(self):
100 | """Resets the iterator back to the beginning."""
101 | raise NotImplementedError
102 |
103 | @property
104 | @abc.abstractmethod
105 | def element_spec(self) -> ElementSpec:
106 |     """Returns the spec of the elements."""
107 | raise NotImplementedError()
108 |
109 | def save(self, filename: epath.Path):
110 | """Saves the state of the iterator to a file.
111 |
112 | This should only handle this iterator - not iterators in other processes.
113 |
114 | Args:
115 | filename: Name of the checkpoint.
116 | """
117 | raise NotImplementedError()
118 |
119 | def restore(self, filename: epath.Path):
120 | """Restores the iterator from a file (if available).
121 |
122 | This should only handle this iterator - not iterators in other processes.
123 |
124 | Args:
125 | filename: Name of the checkpoint.
126 | """
127 | raise NotImplementedError()
128 |
129 | def load(self, filename: epath.Path):
130 | logging.error("DatasetIterator.load() is deprecated. Please use restore().")
131 | return self.restore(filename)
132 |
133 |
134 | class TfDatasetIterator(DatasetIterator):
135 | """DatasetIterator for wrapping a `tf.data.Dataset`."""
136 |
137 | def __init__(self, dataset, *, checkpoint: bool):
138 | """Wraps `tf.data.Dataset` object into the `DatasetIterator` interface.
139 |
140 |     Warning: Do not wrap this iterator to do asynchronous prefetching if you
141 |     use `checkpoint=True`. tf.data iterators must be saved
142 |     synchronously.
143 |
144 | Args:
145 | dataset: The dataset to wrap. Elements are converted to NumPy arrays but
146 | no additional prefetching is done. tf.data should automatically prefetch
147 | elements (to CPU memory).
148 | checkpoint: Whether to checkpoint the dataset iterator object.
149 | Checkpointing dataset iterators is required for handling job
150 | pre-emptions but depending on your input pipeline can result in very
151 | large checkpoints. If set to False save() and load() are no-ops.
152 |         large checkpoints. If set to False, save() and restore() are no-ops.
153 | try:
154 | # Since this is the only class in this module using TF we only import
155 | # tensorflow if needed.
156 | if typing.TYPE_CHECKING:
157 | tf = Any
158 | else:
159 | import tensorflow as tf # pylint: disable=g-import-not-at-top
160 | except ImportError as e:
161 | raise RuntimeError("When using TfDatasetIterator your binary must "
162 | "depend on //third_party/py/tensorflow.") from e
163 | self._tf = tf
164 |
165 | if not isinstance(dataset, tf.data.Dataset):
166 | raise ValueError("`dataset` must be an instance of `tf.data.Dataset` "
167 | f"but got {type(dataset)}.")
168 | self._dataset = dataset
169 | self._checkpoint = checkpoint
170 | assert self.element_spec # Verify element spec.
171 | self.iterator = iter(dataset)
172 | self._ckpt = tf.train.Checkpoint(ds=self.iterator)
173 |
174 | def get_next(self) -> Element:
175 | return next(self)
176 |
177 | def __next__(self) -> Element:
178 | return {k: np.asarray(v) for k, v in next(self.iterator).items()}
179 |
180 | def reset(self):
181 | self.iterator = iter(self._dataset)
182 | self._ckpt = self._tf.train.Checkpoint(ds=self.iterator)
183 |
184 | @property
185 | def element_spec(self) -> ElementSpec:
186 | element_spec = self._dataset.element_spec
187 | if not isinstance(element_spec, dict):
188 | raise ValueError("Dataset elements must be flat dictionaries but got "
189 | f"{element_spec}.")
190 | invalid_features = [
191 | k for k, v in element_spec.items()
192 | if not isinstance(v, self._tf.TensorSpec)
193 | ]
194 | if invalid_features:
195 | raise ValueError(f"Features {invalid_features} are not tensors. Dataset "
196 | "elements must be flat dictionaries of tensors.")
197 | return {
198 | k: ArraySpec(dtype=v.dtype.as_numpy_dtype, shape=tuple(v.shape))
199 | for k, v in element_spec.items()
200 | }
201 |
202 | def save(self, filename: epath.Path):
203 | if self._checkpoint:
204 | self._ckpt.write(os.fspath(filename))
205 |
206 | def restore(self, filename: epath.Path):
207 | if self._checkpoint:
208 | self._ckpt.read(os.fspath(filename)).assert_consumed()
209 |
210 |
211 | class PeekableDatasetIterator(DatasetIterator):
212 | """Wraps a DatasetIterator to provide a peek() method.
213 |
214 |   This allows looking at the next element, which is useful in two scenarios:
215 | a) Get the structure of elements if the element_spec property is not
216 | supported.
217 | b) Request the next element without consuming it. This is especially handy to
218 | trigger reading of the first element while the model is being initialized.
219 |
220 | Example use case:
221 | >>> pool = clu.asynclib.Pool()
222 | >>> @pool
223 | >>> def warmup_input_pipeline():
224 | >>> train_iter.peek()
225 | >>> first_batch_ready = warmup_input_pipeline()
226 | >>> # Do other stuff...
227 | >>> first_batch_ready.result() # wait for input pipeline to be ready.
228 | """
229 |
230 | def __init__(self, it: DatasetIterator):
231 | self._it = it
232 | # Mutex for self._it.
233 | self._mutex = threading.Lock()
234 | self._peek: Optional[Element] = None
235 | self._pool = None
236 | self._peek_future = None
237 |
238 | def __next__(self) -> Element:
239 | with self._mutex:
240 | if self._peek is None:
241 | return next(self._it)
242 | peek = self._peek
243 | self._peek = None
244 | return peek
245 |
246 | def reset(self):
247 | with self._mutex:
248 | self._it.reset()
249 | self._peek = None
250 | self._pool = None
251 | self._peek_future = None
252 |
253 | @property
254 | def element_spec(self) -> ElementSpec:
255 | return self._it.element_spec
256 |
257 | def peek(self) -> Element:
258 | """Returns the next element without consuming it.
259 |
260 | This will get the next element from the underlying iterator. The element
261 |     is stored and returned on the next call of __next__().
262 |
263 | Returns:
264 | The next element.
265 | """
266 | if self._peek is None:
267 | self._peek = next(self)
268 | return self._peek
269 |
270 | def peek_async(self) -> concurrent.futures.Future[Element]:
271 | """Same as peek() but returns the Future of the element.
272 |
273 | Users can call this to warm up the iterator.
274 |
275 | Returns:
276 | Future with the next element. The element is also kept and returned on the
277 | next call of __next__().
278 | """
279 | with self._mutex:
280 | if self._peek_future is None:
281 | if self._pool is None:
282 | self._pool = asynclib.Pool(max_workers=1)
283 | self._peek_future = self._pool(self.peek)()
284 | return self._peek_future
285 |
286 | def save(self, filename: epath.Path):
287 | with self._mutex:
288 | self._it.save(filename)
289 |
290 | def restore(self, filename: epath.Path):
291 | with self._mutex:
292 | self._it.restore(filename)
293 |
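294 |
295 | # Illustrative usage sketch (assumes a tf.data input pipeline; names are
296 | # placeholders):
297 | #
298 | #   ds = tf.data.Dataset.from_tensor_slices({"x": np.arange(8)}).batch(2)
299 | #   it = PeekableDatasetIterator(TfDatasetIterator(ds, checkpoint=True))
300 | #   first_batch_ready = it.peek_async()  # Warm up the input pipeline.
301 | #   batch = next(it)                     # Flat dict of NumPy arrays.
302 | #   it.save(epath.Path("/tmp/ckpt/iterator"))  # Persist the iterator state.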
--------------------------------------------------------------------------------
/clu/periodic_actions_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for perodic actions."""
16 |
17 | import tempfile
18 | import time
19 | from unittest import mock
20 |
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | from clu import periodic_actions
24 |
25 |
26 | class ReportProgressTest(parameterized.TestCase):
27 |
28 | def test_every_steps(self):
29 | hook = periodic_actions.ReportProgress(
30 | every_steps=4, every_secs=None, num_train_steps=10
31 | )
32 | t = time.monotonic()
33 | with self.assertLogs(level="INFO") as logs:
34 | self.assertFalse(hook(1, t))
35 | t += 0.11
36 | self.assertFalse(hook(2, t))
37 | t += 0.13
38 | self.assertFalse(hook(3, t))
39 | t += 0.12
40 | self.assertTrue(hook(4, t))
41 | # We did 1 step every 0.12s => 8.333 steps/s.
42 | self.assertEqual(
43 | logs.output,
44 | [
45 | "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
46 | " ETA: 0m"
47 | ],
48 | )
49 |
50 | def test_every_secs(self):
51 | hook = periodic_actions.ReportProgress(
52 | every_steps=None, every_secs=0.3, num_train_steps=10
53 | )
54 | t = time.monotonic()
55 | with self.assertLogs(level="INFO") as logs:
56 | self.assertFalse(hook(1, t))
57 | t += 0.11
58 | self.assertFalse(hook(2, t))
59 | t += 0.13
60 | self.assertFalse(hook(3, t))
61 | t += 0.12
62 | self.assertTrue(hook(4, t))
63 | # We did 1 step every 0.12s => 8.333 steps/s.
64 | self.assertEqual(
65 | logs.output,
66 | [
67 | "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
68 | " ETA: 0m"
69 | ],
70 | )
71 |
72 | def test_without_num_train_steps(self):
73 | report = periodic_actions.ReportProgress(every_steps=2)
74 | t = time.monotonic()
75 | with self.assertLogs(level="INFO") as logs:
76 | self.assertFalse(report(1, t))
77 | self.assertTrue(report(2, t + 0.12))
78 | # We did 1 step in 0.12s => 8.333 steps/s.
79 | self.assertEqual(
80 | logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
81 | )
82 |
83 | def test_with_persistent_notes(self):
84 | report = periodic_actions.ReportProgress(every_steps=2)
85 | report.set_persistent_notes("Hello world")
86 | t = time.monotonic()
87 | with self.assertLogs(level="INFO") as logs:
88 | self.assertFalse(report(1, t))
89 | self.assertTrue(report(2, t + 0.12))
90 | # We did 1 step in 0.12s => 8.333 steps/s.
91 | self.assertEqual(
92 | logs.output,
93 | ["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"],
94 | )
95 |
96 | def test_unknown_cardinality(self):
97 | report = periodic_actions.ReportProgress(every_steps=2)
98 | t = time.monotonic()
99 | with self.assertLogs(level="INFO") as logs:
100 | self.assertFalse(report(1, t))
101 | self.assertTrue(report(2, t + 0.12))
102 | # We did 1 step in 0.12s => 8.333 steps/s.
103 | self.assertEqual(
104 | logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
105 | )
106 |
107 | def test_called_every_step(self):
108 | hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
109 | t = time.monotonic()
110 | with self.assertRaisesRegex(
111 | ValueError, "PeriodicAction must be called after every step"
112 | ):
113 | hook(1, t)
114 | hook(11, t) # Raises exception.
115 |
116 | @parameterized.named_parameters(
117 | ("_nowait", False),
118 | ("_wait", True),
119 | )
120 | @mock.patch("time.monotonic")
121 | def test_named(self, wait_jax_async_dispatch, mock_time):
122 | mock_time.return_value = 0
123 | hook = periodic_actions.ReportProgress(
124 | every_steps=1, every_secs=None, num_train_steps=10
125 | )
126 |
127 | def _wait():
128 | # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
129 | hook._executor.submit(lambda: None).result()
130 |
131 | self.assertFalse(hook(1)) # Never triggers on first execution.
132 | with hook.timed("test1", wait_jax_async_dispatch):
133 | _wait()
134 | mock_time.return_value = 1
135 | _wait()
136 | with hook.timed("test2", wait_jax_async_dispatch):
137 | _wait()
138 | mock_time.return_value = 2
139 | _wait()
140 | with hook.timed("test1", wait_jax_async_dispatch):
141 | _wait()
142 | mock_time.return_value = 3
143 | _wait()
144 | mock_time.return_value = 4
145 | with self.assertLogs(level="INFO") as logs:
146 | self.assertTrue(hook(2))
147 | self.assertEqual(
148 | logs.output,
149 | [
150 | "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
151 | " 0m (0m : 50.0% test1, 25.0% test2)"
152 | ],
153 | )
154 |
155 | @mock.patch("time.monotonic")
156 | def test_write_metrics(self, time_mock):
157 | time_mock.return_value = 0
158 | writer_mock = mock.Mock()
159 | hook = periodic_actions.ReportProgress(
160 | every_steps=2, every_secs=None, writer=writer_mock
161 | )
162 | time_mock.return_value = 1
163 | hook(1)
164 | time_mock.return_value = 2
165 | hook(2)
166 | self.assertEqual(
167 | writer_mock.write_scalars.mock_calls,
168 | [
169 | mock.call(2, {"steps_per_sec": 1}),
170 | mock.call(2, {"uptime": 2}),
171 | ],
172 | )
173 |
174 |
175 | class DummyProfilerSession:
176 | """Dummy Profiler that records the steps at which sessions started/ended."""
177 |
178 | def __init__(self):
179 | self.step = None
180 | self.start_session_call_steps = []
181 | self.end_session_call_steps = []
182 |
183 | def start_session(self):
184 | self.start_session_call_steps.append(self.step)
185 |
186 | def end_session_and_get_url(self, tag):
187 | del tag
188 | self.end_session_call_steps.append(self.step)
189 |
190 |
191 | class ProfileTest(absltest.TestCase):
192 |
193 | @mock.patch.object(periodic_actions, "profiler", autospec=True)
194 | @mock.patch("time.monotonic")
195 | def test_every_steps(self, mock_time, mock_profiler):
196 | start_steps = []
197 | stop_steps = []
198 | step = 0
199 |
200 | def add_start_step(logdir):
201 | del logdir # unused
202 | start_steps.append(step)
203 |
204 | def add_stop_step():
205 | stop_steps.append(step)
206 |
207 | mock_profiler.start.side_effect = add_start_step
208 | mock_profiler.stop.side_effect = add_stop_step
209 | hook = periodic_actions.Profile(
210 | logdir=tempfile.mkdtemp(),
211 | num_profile_steps=2,
212 | profile_duration_ms=2_000,
213 | first_profile=3,
214 | every_steps=7,
215 | )
216 | for step in range(1, 18):
217 | mock_time.return_value = step - 0.5 if step == 9 else step
218 | hook(step)
219 | self.assertEqual([3, 7, 14], start_steps)
220 | # Note: profiling 7..10 instead of 7..9 because 7..9 took only 1.5 seconds.
221 | self.assertEqual([5, 10, 16], stop_steps)
222 |
223 |
224 | class ProfileAllHostsTest(absltest.TestCase):
225 |
226 | @mock.patch.object(periodic_actions, "profiler", autospec=True)
227 | def test_every_steps(self, mock_profiler):
228 | start_steps = []
229 | step = 0
230 |
231 | def profile_collect(logdir, callback, hosts, duration_ms):
232 | del logdir, callback, hosts, duration_ms # unused
233 | start_steps.append(step)
234 |
235 | mock_profiler.collect.side_effect = profile_collect
236 | hook = periodic_actions.ProfileAllHosts(
237 | logdir=tempfile.mkdtemp(),
238 | profile_duration_ms=2_000,
239 | first_profile=3,
240 | every_steps=7,
241 | )
242 | for step in range(1, 18):
243 | hook(step)
244 | self.assertEqual([3, 7, 14], start_steps)
245 |
246 |
247 | class PeriodicCallbackTest(absltest.TestCase):
248 |
249 | def test_every_steps(self):
250 | callback = mock.Mock()
251 | hook = periodic_actions.PeriodicCallback(
252 | every_steps=2, callback_fn=callback
253 | )
254 |
255 | for step in range(1, 10):
256 | hook(step, 3, remainder=step % 3)
257 |
258 | expected_calls = [
259 | mock.call(remainder=2, step=2, t=3),
260 | mock.call(remainder=1, step=4, t=3),
261 | mock.call(remainder=0, step=6, t=3),
262 | mock.call(remainder=2, step=8, t=3),
263 | ]
264 | self.assertListEqual(expected_calls, callback.call_args_list)
265 |
266 | @mock.patch("time.monotonic")
267 | def test_every_secs(self, mock_time):
268 | callback = mock.Mock()
269 | hook = periodic_actions.PeriodicCallback(every_secs=2, callback_fn=callback)
270 |
271 | for step in range(1, 10):
272 | mock_time.return_value = float(step)
273 | hook(step, remainder=step % 5)
274 | # Note: time will be initialized at 1 so hook runs at steps 4 & 7.
275 | expected_calls = [
276 | mock.call(remainder=4, step=4, t=4.0),
277 | mock.call(remainder=2, step=7, t=7.0),
278 | ]
279 | self.assertListEqual(expected_calls, callback.call_args_list)
280 |
281 | def test_on_steps(self):
282 | callback = mock.Mock()
283 | hook = periodic_actions.PeriodicCallback(on_steps=[8], callback_fn=callback)
284 |
285 | for step in range(1, 10):
286 | hook(step, remainder=step % 3)
287 |
288 | callback.assert_called_once_with(remainder=2, step=8, t=mock.ANY)
289 |
290 | def test_async_execution(self):
291 | out = []
292 |
293 | def cb(step, t):
294 | del t
295 | out.append(step)
296 |
297 | hook = periodic_actions.PeriodicCallback(
298 | every_steps=1, callback_fn=cb, execute_async=True
299 | )
300 | hook(0)
301 | hook(1)
302 | hook(2)
303 | hook(3)
304 | # Block till all the hooks have finished.
305 | hook.get_last_callback_result().result()
306 | # Check order of execution is preserved.
307 | self.assertListEqual(out, [0, 1, 2, 3])
308 |
309 | def test_error_async_is_forwarded(self):
310 |
311 | def cb(step, t):
312 | del step
313 | del t
314 | raise Exception
315 |
316 | hook = periodic_actions.PeriodicCallback(
317 | every_steps=1, callback_fn=cb, execute_async=True
318 | )
319 |
320 | hook(0)
321 |
322 | with self.assertRaises(Exception):
323 | hook(1)
324 |
325 | def test_function_without_step_and_time(self):
326 |
327 | # This must be used with pass_step_and_time=False.
328 | def cb():
329 | return 5
330 |
331 | hook = periodic_actions.PeriodicCallback(
332 | every_steps=1, callback_fn=cb, pass_step_and_time=False
333 | )
334 | hook(0)
335 | hook(1)
336 | self.assertEqual(hook.get_last_callback_result(), 5)
337 |
338 |
339 | if __name__ == "__main__":
340 | absltest.main()
341 |
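342 |
343 | # Illustrative sketch of how the hooks above are typically driven from a
344 | # training loop (workdir, num_train_steps and train_step are placeholders):
345 | #
346 | #   report_progress = periodic_actions.ReportProgress(
347 | #       every_secs=60, num_train_steps=num_train_steps)
348 | #   profile = periodic_actions.Profile(logdir=workdir, every_steps=1000)
349 | #   for step in range(1, num_train_steps + 1):
350 | #     train_step()
351 | #     report_progress(step)
352 | #     profile(step)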
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/clu/parameter_overview.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper function for creating and logging JAX variable overviews."""
16 |
17 | from collections.abc import Callable, Mapping, Sequence
18 | import dataclasses
19 | from typing import Any
20 |
21 | from absl import logging
22 |
23 | import flax
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 |
28 | _ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]]
29 |
30 |
31 | @dataclasses.dataclass
32 | class _ParamRow:
33 | name: str
34 | shape: tuple[int, ...]
35 | dtype: str
36 | size: int
37 |
38 |
39 | @dataclasses.dataclass
40 | class _ParamRowWithSharding(_ParamRow):
41 | sharding: tuple[int | None, ...] | str
42 |
43 |
44 | @dataclasses.dataclass
45 | class _ParamRowWithStats(_ParamRow):
46 | mean: float
47 | std: float
48 |
49 |
50 | @dataclasses.dataclass
51 | class _ParamRowWithStatsAndSharding(_ParamRowWithStats):
52 | sharding: tuple[int | None, ...] | str
53 |
54 |
55 | @jax.jit
56 | def _mean_std_jit(x):
57 | return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x)
58 |
59 |
60 | def _mean_std(x):
61 | mean = jax.tree_util.tree_map(lambda x: x.mean(), x)
62 | std = jax.tree_util.tree_map(lambda x: x.std(), x)
63 | return mean, std
64 |
65 |
66 | def flatten_dict(
67 | input_dict: dict[str, Any], *, prefix: str = "", delimiter: str = "/"
68 | ) -> dict[str, Any]:
69 | """Flattens the keys of a nested dictionary."""
70 | output_dict = {}
71 | for key, value in input_dict.items():
72 | nested_key = f"{prefix}{delimiter}{key}" if prefix else key
73 | if isinstance(value, (dict, flax.core.FrozenDict)):
74 | output_dict.update(
75 | flatten_dict(value, prefix=nested_key, delimiter=delimiter)
76 | )
77 | else:
78 | output_dict[nested_key] = value
79 | return output_dict
80 |
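# Example (illustrative sketch): flatten_dict() joins nested keys with the
# delimiter, which is the key format used by the row helpers and the
# parameter counters below.
#
#   flatten_dict({"dense": {"kernel": np.zeros((3, 4)), "bias": np.zeros(4)}})
#   # -> {"dense/kernel": <(3, 4) array>, "dense/bias": <(4,) array>}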
81 |
82 | def _count_parameters(params: _ParamsContainer) -> int:
83 | """Returns the count of variables for the module or parameter dictionary."""
84 | params = flatten_dict(params)
85 | return sum(np.prod(v.shape) for v in params.values() if v is not None)
86 |
87 |
88 | def _parameters_size(params: _ParamsContainer) -> int:
89 | """Returns total size (bytes) for the module or parameter dictionary."""
90 | params = flatten_dict(params)
91 | return sum(
92 | np.prod(v.shape) * v.dtype.itemsize
93 | for v in params.values()
94 | if v is not None
95 | )
96 |
97 |
98 | def count_parameters(params: _ParamsContainer) -> int:
99 | """Returns the count of variables for the module or parameter dictionary."""
100 |
101 | return _count_parameters(params)
102 |
103 |
104 | def _make_row(name, value) -> _ParamRow:
105 | if value is None:
106 | return _ParamRow(
107 | name=name,
108 | shape=(),
109 | dtype="",
110 | size=0,
111 | )
112 | return _ParamRow(
113 | name=name,
114 | shape=value.shape,
115 | dtype=str(value.dtype),
116 | size=int(np.prod(value.shape)),
117 | )
118 |
119 |
120 | def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
121 | row = _make_row(name, value)
122 | if hasattr(value, "sharding"):
123 | if hasattr(value.sharding, "spec"):
124 | sharding = tuple(value.sharding.spec)
125 | else:
126 | sharding = str(value.sharding)
127 | else:
128 | sharding = ()
129 | return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding)
130 |
131 |
132 | def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
133 | row = _make_row(name, value)
134 | mean = mean or 0.0
135 | std = std or 0.0
136 | return _ParamRowWithStats(
137 | **dataclasses.asdict(row),
138 | mean=float(jax.device_get(mean)),
139 | std=float(jax.device_get(std)),
140 | )
141 |
142 |
143 | def _make_row_with_stats_and_sharding(
144 | name, value, mean, std
145 | ) -> _ParamRowWithStatsAndSharding:
146 | row = _make_row_with_sharding(name, value)
147 | return _ParamRowWithStatsAndSharding(
148 | **dataclasses.asdict(row),
149 | mean=float(jax.device_get(mean)),
150 | std=float(jax.device_get(std)),
151 | )
152 |
153 |
154 | def _get_parameter_rows(
155 | params: _ParamsContainer,
156 | *,
157 | include_stats: bool | str = False,
158 | ) -> list[_ParamRow]:
159 | """Returns information about parameters as a list of dictionaries.
160 |
161 | Args:
162 | params: Dictionary with parameters as NumPy arrays. The dictionary can be
163 |       nested. `params` must be a `dict` or a `flax.core.FrozenDict`; other
164 |       containers (e.g. a `tf.Module`) are not supported and raise a `ValueError`.
165 | include_stats: If True, add columns with mean and std for each variable. If
166 |       the string "sharding", add a column with the sharding of the
167 | variable. If the string "global", params are sharded global arrays and
168 | this function assumes it is called on every host, i.e. can use
169 | collectives. The sharding of the variables is also added as a column.
170 |
171 | Returns:
172 |     A list of `_ParamRow` (or one of its subclasses with stats/sharding),
173 |     depending on the passed value of `include_stats`.
174 | """
175 | if not isinstance(params, (dict, flax.core.FrozenDict)):
176 | raise ValueError(
177 | f"Expected `params` to be a dictionary but got {type(params)}"
178 | )
179 |
180 | params = flatten_dict(params)
181 | if params:
182 | names, values = map(list, tuple(zip(*sorted(params.items()))))
183 | else:
184 | names, values = [], []
185 |
186 | match include_stats:
187 | case False:
188 | return jax.tree_util.tree_map(_make_row, names, values)
189 |
190 | case True:
191 | mean_and_std = _mean_std(values)
192 | return jax.tree_util.tree_map(
193 | _make_row_with_stats, names, values, *mean_and_std
194 | )
195 |
196 | case "global":
197 | mean_and_std = _mean_std_jit(values)
198 | return jax.tree_util.tree_map(
199 | _make_row_with_stats_and_sharding, names, values, *mean_and_std
200 | )
201 |
202 | case "sharding":
203 | return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
204 |
205 | case _:
206 | raise ValueError(f"Unknown `include_stats`: {include_stats}")
207 |
208 |
209 | def _default_table_value_formatter(value):
210 |   """Formats ints with "," between thousands and floats to 3 significant digits."""
211 | if isinstance(value, bool):
212 | return str(value)
213 | elif isinstance(value, int):
214 | return "{:,}".format(value)
215 | elif isinstance(value, float):
216 | return "{:.3}".format(value)
217 | else:
218 | return str(value)
219 |
220 |
221 | def make_table(
222 | rows: list[Any],
223 | *,
224 | column_names: Sequence[str] | None = None,
225 | value_formatter: Callable[[Any], str] = _default_table_value_formatter,
226 | max_lines: int | None = None,
227 | ) -> str:
228 | """Renders a list of rows to a table.
229 |
230 | Args:
231 | rows: List of dataclass instances of a single type (e.g. `ParamRow`).
232 |     column_names: List of columns that should be included in the output. If
233 | not provided, then the columns are taken from keys of the first row.
234 | value_formatter: Callable used to format cell values.
235 | max_lines: Don't render a table longer than this.
236 |
237 | Returns:
238 | A string representation of the table in the form:
239 |
240 | +---------+---------+
241 | | Col1 | Col2 |
242 | +---------+---------+
243 | | value11 | value12 |
244 | | value21 | value22 |
245 | +---------+---------+
246 | """
247 |
248 | if any(not dataclasses.is_dataclass(row) for row in rows):
249 |     raise ValueError("Expected `rows` to be a list of dataclasses.")
250 |   if len(set(map(type, rows))) > 1:
251 |     raise ValueError("Expected all elements of `rows` to be of the same type.")
252 |
253 | class Column:
254 |
255 | def __init__(self, name, values):
256 | self.name = name.capitalize()
257 | self.values = values
258 | self.width = max(len(v) for v in values + [name])
259 |
260 | if column_names is None:
261 | if not rows:
262 | return "(empty table)"
263 | column_names = [field.name for field in dataclasses.fields(rows[0])]
264 |
265 | columns = [
266 | Column(name, [value_formatter(getattr(row, name)) for row in rows])
267 | for name in column_names
268 | ]
269 |
270 | var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns)
271 | sep_line_format = var_line_format.replace(" ", "-").replace("|", "+")
272 | header = var_line_format.replace(">", "<").format(*[c.name for c in columns])
273 | separator = sep_line_format.format(*["" for c in columns])
274 |
275 | lines = [separator, header, separator]
276 | for i in range(len(rows)):
277 | if max_lines and len(lines) >= max_lines - 3:
278 | lines.append("[...]")
279 | break
280 | lines.append(var_line_format.format(*[c.values[i] for c in columns]))
281 | lines.append(separator)
282 |
283 | return "\n".join(lines)
284 |
285 |
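# A minimal usage sketch for make_table(). The helper below is illustrative
# only (it is not part of the original module); any list of dataclass
# instances of a single type works, not just the `_ParamRow` variants.
def _make_table_example() -> str:
  rows = [
      _ParamRow(name="dense/kernel", shape=(3, 4), dtype="float32", size=12),
      _ParamRow(name="dense/bias", shape=(4,), dtype="float32", size=4),
  ]
  # Renders a "+---+"-delimited table with one line per dataclass instance.
  return make_table(rows)

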
286 | def _get_parameter_overview(
287 | params: _ParamsContainer,
288 | *,
289 | include_stats: bool | str = True,
290 | max_lines: int | None = None,
291 | ) -> str:
292 | """See get_parameter_overview()."""
293 | if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison
294 | params = jax.device_get(params) # A no-op if already numpy array.
295 | rows = _get_parameter_rows(params, include_stats=include_stats)
296 | RowType = { # pylint: disable=invalid-name
297 | False: _ParamRow,
298 | True: _ParamRowWithStats,
299 | "global": _ParamRowWithStatsAndSharding,
300 | "sharding": _ParamRowWithSharding,
301 | }[include_stats]
302 | # Pass in `column_names` to enable rendering empty tables.
303 | column_names = [field.name for field in dataclasses.fields(RowType)]
304 | table = make_table(rows, max_lines=max_lines, column_names=column_names)
305 | total_weights = _count_parameters(params)
306 | total_size = _parameters_size(params)
307 | return table + f"\nTotal: {total_weights:,} -- {total_size:,} bytes"
308 |
309 |
310 | def get_parameter_overview(
311 | params: _ParamsContainer,
312 | *,
313 | include_stats: bool | str = True,
314 | max_lines: int | None = None,
315 | ) -> str:
316 |   """Returns a string with variable names, their shapes and counts.
317 |
318 | Args:
319 | params: Dictionary with parameters as NumPy arrays. The dictionary can be
320 | nested.
321 | include_stats: If True, add columns with mean and std for each variable. If
322 |       the string "sharding", add a column with the sharding of the
323 | variable. If the string "global", params are sharded global arrays and
324 | this function assumes it is called on every host, i.e. can use
325 | collectives. The sharding of the variables is also added as a column.
326 | max_lines: If not `None`, the maximum number of variables to include.
327 |
328 | Returns:
329 | A string with a table like in the example.
330 |
331 | +----------------+---------------+------------+
332 | | Name | Shape | Size |
333 | +----------------+---------------+------------+
334 | | FC_1/weights:0 | (63612, 1024) | 65,138,688 |
335 | | FC_1/biases:0 | (1024,) | 1,024 |
336 | | FC_2/weights:0 | (1024, 32) | 32,768 |
337 | | FC_2/biases:0 | (32,) | 32 |
338 | +----------------+---------------+------------+
339 | Total: 65,172,512
340 | """
341 |
342 | return _get_parameter_overview(
343 | params, include_stats=include_stats, max_lines=max_lines
344 | )
345 |
346 |
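# Example usage (an illustrative sketch; this helper is not part of the
# original module):
def _get_parameter_overview_example() -> str:
  params = {
      "dense": {
          "kernel": np.zeros((16, 4), np.float32),
          "bias": np.zeros((4,), np.float32),
      },
  }
  # Returns the rendered table plus a "Total: ..." line. include_stats=False
  # skips the mean/std columns, so no statistics are computed.
  return get_parameter_overview(params, include_stats=False)

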
347 | def _log_parameter_overview(
348 | params: _ParamsContainer,
349 | *,
350 | include_stats: bool | str = True,
351 | max_lines: int | None = None,
352 | msg: str | None = None,
353 | jax_logging_process: int | None = None,
354 | ):
355 | """See log_parameter_overview()."""
356 |
357 | table = _get_parameter_overview(
358 | params, include_stats=include_stats, max_lines=max_lines
359 | )
360 | if jax_logging_process is None or jax_logging_process == jax.process_index():
361 | lines = [msg] if msg else []
362 | lines += table.split("\n")
363 | # The table can be too large to fit into one log entry.
364 | for i in range(0, len(lines), 80):
365 | logging.info("\n%s", "\n".join(lines[i : i + 80]))
366 |
367 |
368 | def log_parameter_overview(
369 | params: _ParamsContainer,
370 | *,
371 | include_stats: bool | str = True,
372 | max_lines: int | None = None,
373 | msg: str | None = None,
374 | jax_logging_process: int | None = None,
375 | ):
376 | """Writes a table with variables name and shapes to INFO log.
377 |
378 | See get_parameter_overview for details.
379 |
380 | Args:
381 | params: Dictionary with parameters as NumPy arrays. The dictionary can be
382 | nested.
383 | include_stats: If True, add columns with mean and std for each variable. If
384 | the string "global", params are sharded global arrays and this function
385 | assumes it is called on every host, i.e. can use collectives.
386 | max_lines: If not `None`, the maximum number of variables to include.
387 | msg: Message to be logged before the overview.
388 | jax_logging_process: Which JAX process ID should do the logging. None = all.
389 | Use this to avoid logspam when include_stats="global".
390 | """
391 |
392 | _log_parameter_overview(
393 | params,
394 | include_stats=include_stats,
395 | max_lines=max_lines,
396 | msg=msg,
397 | jax_logging_process=jax_logging_process,
398 | )
399 |
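# Typical call site (a sketch; `state.params` and the surrounding training
# loop are hypothetical, not defined in this module):
#
#   log_parameter_overview(
#       state.params,
#       include_stats="global",   # params are sharded global arrays
#       msg="Model parameters",
#       jax_logging_process=0,    # log from process 0 only to avoid logspam
#   )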
--------------------------------------------------------------------------------
/clu/deterministic_data_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for the deterministic_data module."""
16 | import dataclasses
17 | import itertools
18 | import math
19 |
20 | from typing import Dict
21 | from unittest import mock
22 |
23 | from absl.testing import parameterized
24 | from clu import deterministic_data
25 | import jax
26 | from packaging import version
27 | import tensorflow as tf
28 | import tensorflow_datasets as tfds
29 |
30 | _use_split_info = version.parse("4.4.0") < version.parse(
31 | tfds.version.__version__)
32 |
33 |
34 | @dataclasses.dataclass
35 | class MyDatasetBuilder:
36 |
37 | name2len: Dict[str, int] # Number of examples per split.
38 |
39 | def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool,
40 | read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset:
41 | del shuffle_files, read_config, decoders
42 | if _use_split_info:
43 | split_infos = {
44 | k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
45 | for k, v in self.name2len.items()
46 | }
47 | instructions = split.to_absolute(split_infos)
48 | else:
49 | instructions = split.to_absolute(self.name2len)
50 | assert len(instructions) == 1
51 | from_ = instructions[0].from_ or 0
52 | to = instructions[0].to or self.name2len[instructions[0].splitname]
53 | return tf.data.Dataset.range(from_, to).map(lambda i: {"index": i})
54 |
55 |
56 | @dataclasses.dataclass
57 | class FakeDatasetInfo:
58 | train_size: int = 9
59 | test_size: int = 8
60 |
61 | @property
62 | def splits(self):
63 | return {
64 | "train": tfds.core.SplitInfo("train", [self.train_size], 0),
65 | "test": tfds.core.SplitInfo("test", [self.test_size], 0)
66 | }
67 |
68 |
69 | class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
70 | """Tests for deterministic_data module."""
71 |
72 | @parameterized.parameters(
73 | (9, 0, 1, True, "test[0:9]"),
74 | (9, 0, 2, True, "test[0:4]"),
75 | (9, 1, 2, True, "test[4:8]"), # Last example gets dropped.
76 | (9, 0, 3, True, "test[0:3]"),
77 | (9, 1, 3, True, "test[3:6]"),
78 | (9, 2, 3, True, "test[6:9]"),
79 | (9, 0, 1, False, "test[0:9]"),
80 | (9, 0, 2, False, "test[0:5]"), # First host gets an extra example.
81 | (9, 1, 2, False, "test[5:9]"),
82 | (8, 0, 3, False, "test[0:3]"), # First 2 hosts get 1 example each.
83 | (8, 1, 3, False, "test[3:6]"),
84 | (8, 2, 3, False, "test[6:8]"),
85 | )
86 | def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
87 | host_id: int,
88 | host_count: int,
89 | drop_remainder: bool,
90 | expected_spec: str):
91 | expected = tfds.core.ReadInstruction.from_spec(expected_spec)
92 | actual = deterministic_data.get_read_instruction_for_host(
93 | "test",
94 | num_examples,
95 | host_id=host_id,
96 | host_count=host_count,
97 | drop_remainder=drop_remainder)
98 | if _use_split_info:
99 | split_infos = {
100 | "test": tfds.core.SplitInfo(
101 | name="test",
102 | shard_lengths=[9],
103 | num_bytes=0,
104 | )}
105 | else:
106 | split_infos = {"test": 9}
107 | self.assertEqual(
108 | expected.to_absolute(split_infos), actual.to_absolute(split_infos))
109 |
110 | @parameterized.parameters(
111 |       # host_id, host_count, drop_remainder, spec, expected_spec_for_host
112 | # train split has 9 examples.
113 | (0, 1, True, "train", "train[0:9]"),
114 | (0, 2, True, "train", "train[0:4]"),
115 | (1, 2, True, "train", "train[4:8]"), # Last example gets dropped.
116 | (0, 3, True, "train", "train[0:3]"),
117 | (1, 3, True, "train", "train[3:6]"),
118 | (2, 3, True, "train", "train[6:9]"),
119 | (0, 1, False, "train", "train[0:9]"),
120 | (0, 2, False, "train", "train[0:5]"), # First host gets an extra example.
121 | (1, 2, False, "train", "train[5:9]"),
122 | # test split has 8 examples.
123 | (0, 3, False, "test", "test[0:3]"), # First 2 hosts get 1 example each.
124 | (1, 3, False, "test", "test[3:6]"),
125 | (2, 3, False, "test", "test[6:8]"),
126 | # Subsplits.
127 | (0, 2, True, "train[:50%]", "train[0:2]"),
128 | (1, 2, True, "train[:50%]", "train[2:4]"),
129 | (0, 2, True, "train[3:7]", "train[3:5]"),
130 | (1, 2, True, "train[3:7]", "train[5:7]"),
131 | (0, 2, True, "train[3:8]", "train[3:5]"), # Last example gets dropped.
132 | (1, 2, True, "train[3:8]", "train[5:7]"),
133 | # 2 splits.
134 | (0, 2, True, "train[3:7]+test", "train[3:5]+test[0:4]"),
135 | (1, 2, True, "train[3:7]+test", "train[5:7]+test[4:8]"),
136 | # First host gets an extra example.
137 | (0, 2, False, "train[3:8]+test[:5]", "train[3:6]+test[0:3]"),
138 | (1, 2, False, "train[3:8]+test[:5]", "train[6:8]+test[3:5]"),
139 | )
140 | def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
141 | drop_remainder: bool, spec: str,
142 | expected_spec_for_host: str):
143 |
144 | actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
145 | spec,
146 | dataset_info=FakeDatasetInfo(),
147 | host_id=host_id,
148 | host_count=host_count,
149 | drop_remainder=drop_remainder)
150 | expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
151 | expected_spec_for_host)
152 | self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
153 |
154 | @parameterized.parameters(
155 |       # host_id, host_count, balance_remainder, spec, expected_spec_for_host
156 | # test split has 10 examples.
157 | (0, 1, True, "test", "test[0:10]"),
158 | (0, 1, False, "test", "test[0:10]"),
159 | (0, 4, True, "test", "test[0:3]"),
160 | (1, 4, True, "test", "test[3:6]"),
161 | (2, 4, True, "test", "test[6:8]"),
162 | (3, 4, True, "test", "test[8:10]"),
163 | (0, 4, False, "test", "test[0:4]"),
164 | (1, 4, False, "test", "test[4:6]"),
165 | (2, 4, False, "test", "test[6:8]"),
166 | (3, 4, False, "test", "test[8:10]"),
167 | )
168 | def test_get_read_instruction_balance_remainder(self, host_id: int,
169 | host_count: int,
170 | balance_remainder: bool,
171 | spec: str,
172 | expected_spec_for_host: str):
173 | actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
174 | spec,
175 | dataset_info=FakeDatasetInfo(test_size=10),
176 | host_id=host_id,
177 | host_count=host_count,
178 | remainder_options=deterministic_data.RemainderOptions
179 | .BALANCE_ON_PROCESSES if balance_remainder else
180 | deterministic_data.RemainderOptions.ON_FIRST_PROCESS)
181 | expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
182 | expected_spec_for_host)
183 | self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
184 |
185 | @parameterized.parameters(
186 | (0, 0), # No hosts.
187 |       (1, 1),  # Only one host (host_id is zero-based).
188 | (-1, 1), # Negative host_id.
189 | (5, 2), # host_id bigger than number of hosts.
190 | )
191 | def test_get_read_instruction_for_host_fails(self, host_id: int,
192 | host_count: int):
193 | with self.assertRaises(ValueError):
194 | deterministic_data.get_read_instruction_for_host(
195 | "test", 11, host_id=host_id, host_count=host_count)
196 |
197 | def test_preprocess_with_per_example_rng(self):
198 |
199 | def preprocess_fn(features):
200 | features["b"] = tf.random.stateless_uniform([], features["rng"])
201 | return features
202 |
203 | rng = jax.random.PRNGKey(42)
204 | ds_in = tf.data.Dataset.from_tensor_slices({"a": [37.2, 31.2, 39.0]})
205 | ds_out = deterministic_data._preprocess_with_per_example_rng(
206 | ds_in, preprocess_fn, rng=rng)
207 | self.assertAllClose([
208 | {
209 | "a": 37.2,
210 | "b": 0.79542184
211 | },
212 | {
213 | "a": 31.2,
214 | "b": 0.45482683
215 | },
216 | {
217 | "a": 39.0,
218 | "b": 0.85335636
219 | },
220 | ], list(ds_out))
221 |
222 | @parameterized.parameters(*itertools.product([2, "auto"], [True, False]))
223 | def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
224 | dataset_builder = mock.Mock()
225 | dataset = tf.data.Dataset.from_tensor_slices(
226 | dict(x=tf.ones((12, 10)), y=tf.ones(12)))
227 | dataset_builder.as_dataset.return_value = dataset
228 | batch_dims = (2, 5)
229 | ds = deterministic_data.create_dataset(
230 | dataset_builder,
231 | split="(ignored)",
232 | batch_dims=batch_dims,
233 | num_epochs=1,
234 | shuffle=False,
235 | pad_up_to_batches=pad_up_to_batches,
236 | cardinality=12 if cardinality else None,
237 | )
238 | ds_iter = iter(ds)
239 | self.assertAllClose(
240 | dict(
241 | x=tf.ones((2, 5, 10)),
242 | y=tf.ones((2, 5)),
243 | mask=tf.ones((2, 5), bool),
244 | ), next(ds_iter))
245 | self.assertAllClose(
246 | dict(
247 | x=tf.reshape(
248 | tf.concat([tf.ones(
249 | (2, 10)), tf.zeros((8, 10))], axis=0), (2, 5, 10)),
250 | y=tf.reshape(tf.concat([tf.ones(2), tf.zeros(8)], axis=0), (2, 5)),
251 | mask=tf.reshape(
252 | tf.concat(
253 | [tf.ones(2, bool), tf.zeros(8, bool)], axis=0), (2, 5)),
254 | ), next(ds_iter))
255 | with self.assertRaises(StopIteration):
256 | next(ds_iter)
257 |
258 | def test_create_dataset_padding_raises_error_cardinality(self):
259 | dataset_builder = mock.Mock()
260 | dataset = tf.data.Dataset.from_tensor_slices(
261 | dict(x=tf.ones((12, 10)), y=tf.ones(12)))
262 | dataset = dataset.filter(lambda x: True)
263 | dataset_builder.as_dataset.return_value = dataset
264 | batch_dims = (2, 5)
265 | with self.assertRaisesRegex(
266 | ValueError,
267 | r"^Cannot determine dataset cardinality."):
268 | deterministic_data.create_dataset(
269 | dataset_builder,
270 | split="(ignored)",
271 | batch_dims=batch_dims,
272 | num_epochs=1,
273 | shuffle=False,
274 | pad_up_to_batches=2,
275 | cardinality=None,
276 | )
277 |
278 | def test_pad_dataset(self):
279 | dataset = tf.data.Dataset.from_tensor_slices(
280 | dict(x=tf.ones((12, 10)), y=tf.ones(12)))
281 | padded_dataset = deterministic_data.pad_dataset(
282 | dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
283 | self.assertAllClose(
284 | dict(
285 | x=tf.concat([tf.ones(
286 | (12, 10)), tf.zeros((8, 10))], axis=0),
287 | y=tf.concat([tf.ones(12), tf.zeros(8)], axis=0),
288 | mask=tf.concat(
289 | [tf.ones(12, bool), tf.zeros(8, bool)], axis=0)),
290 | next(iter(padded_dataset.batch(20))))
291 |
292 | def test_pad_nested_dataset(self):
293 | dataset = tf.data.Dataset.from_tensor_slices(
294 | {"x": {"z": (tf.ones((12, 10)), tf.ones(12))},
295 | "y": tf.ones((12, 4))})
296 |
297 | def expected(*dims):
298 | return tf.concat([tf.ones((12,) + dims), tf.zeros((8,) + dims)], axis=0)
299 |
300 | padded_dataset = deterministic_data.pad_dataset(
301 | dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
302 | self.assertAllClose(
303 | {"x": {"z": (expected(10), expected())},
304 | "y": expected(4),
305 | "mask": tf.concat([tf.ones(12, bool), tf.zeros(8, bool)], axis=0)},
306 | next(iter(padded_dataset.batch(20))))
307 |
308 | @parameterized.parameters(*itertools.product(range(20), range(1, 4)))
309 | def test_same_cardinality_on_all_hosts(self, num_examples: int,
310 | host_count: int):
311 | builder = MyDatasetBuilder({"train": num_examples})
312 | cardinalities = []
313 | for host_id in range(host_count):
314 | split = deterministic_data.get_read_instruction_for_host(
315 | split="train",
316 | num_examples=num_examples,
317 | host_id=host_id,
318 | host_count=host_count,
319 | drop_remainder=True)
320 | ds = deterministic_data.create_dataset(
321 | builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1)
322 | cardinalities.append(ds.cardinality().numpy().item())
323 | self.assertLen(set(cardinalities), 1)
324 |
325 | @parameterized.parameters(*itertools.product(range(20), range(1, 4)))
326 | def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
327 | host_count: int):
328 | builder = MyDatasetBuilder({"train": num_examples})
329 | # All hosts should have the same number of batches.
330 | batch_size = 2
331 | pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count)))
332 | assert pad_up_to_batches * batch_size * host_count >= num_examples
333 | cardinalities = []
334 | for host_id in range(host_count):
335 | split = deterministic_data.get_read_instruction_for_host(
336 | split="train",
337 | num_examples=num_examples,
338 | host_id=host_id,
339 | host_count=host_count,
340 | drop_remainder=False)
341 | ds = deterministic_data.create_dataset(
342 | builder,
343 | split=split,
344 | batch_dims=[batch_size],
345 | shuffle=False,
346 | num_epochs=1,
347 | pad_up_to_batches=pad_up_to_batches)
348 | cardinalities.append(ds.cardinality().numpy().item())
349 | self.assertLen(set(cardinalities), 1)
350 |
351 |
352 | if __name__ == "__main__":
353 | tf.test.main()
354 |
--------------------------------------------------------------------------------
/clu/preprocess_spec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The CLU Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library for parsing a preprocessing spec.
16 |
17 | A preprocessing spec is a list of preprocessing ops separated by '|' that can be
18 | applied sequentially as a preprocessing function. The preprocessing ops are
19 | provided as input and must implement the PreprocessOp protocol. While not
20 | strictly required, we also recommend annotating preprocess ops as dataclasses.
21 |
22 | By convention the preprocessing function operates on dictionaries of features.
23 | Each op can change the dictionary by modifying, adding or removing dictionary
24 | entries. Dictionary entries should be tensors, keys should be strings.
25 | (For common data types we recommend using the feature keys used in TFDS.)
26 |
27 | Example spec: 'fn1|fn2(3)|fn3(keyword=5)'
28 | This will construct the following preprocessing function:
29 | def preprocess_fn(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
30 | features = fn1(features)
31 | features = fn2(features, 3)
32 | features = fn3(features, keyword=5)
33 | return features
34 |
35 | See preprocess_spec_test.py for some simple examples.
36 | """
37 |
38 | import abc
39 | import ast
40 | import dataclasses
41 | import inspect
42 | import re
43 | import sys
44 | from typing import Dict, List, Sequence, Tuple, Type, TypeVar, Union
45 |
46 | from absl import logging
47 | from flax import traverse_util
48 | import jax.numpy as jnp
49 | import tensorflow as tf
50 | import typing_extensions
51 | from typing_extensions import Protocol
52 |
53 |
54 | # Feature dictionary. Arbitrary nested dictionary with string keys and
55 | # tf.Tensor as leaves.
56 | Tensor = Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor]
57 | # TFDS allows for nested `Features` ...
58 | Features = Dict[str, Union[Tensor, "Features"]]
59 | # ... but it's usually a better idea NOT to nest them. Also better for PyType.
60 | FlatFeatures = Dict[str, Tensor]
61 | D = TypeVar("D", FlatFeatures, tf.data.Dataset)
62 |
63 | # Feature name for the random seed for tf.random.stateless_* ops. By
64 | # convention ops should split off their random seed and keep the SEED_KEY
65 | # feature:
66 | # ```
67 | # features[SEED_KEY], seed = tf.unstack(
68 | #     tf.random.experimental.stateless_split(features[SEED_KEY]))
69 | # ```
70 | SEED_KEY = "_seed"
71 |
72 | # Regex that finds upper case characters.
73 | _CAMEL_CASE_RGX = re.compile(r"(?<!^)(?=[A-Z])")
74 | 
75 | 
76 | @typing_extensions.runtime_checkable
77 | class PreprocessOp(Protocol):
78 |   """Interface for all preprocess ops.
79 | 
80 |   A preprocess op is a callable that accepts a feature dictionary and returns
81 |   a feature dictionary. Ops can modify, add or remove dictionary entries but
82 |   should not keep any state. We recommend implementing ops as frozen
83 |   dataclasses so that they get a readable `__repr__()` and a correct
84 |   `__eq__()`.
85 | 
86 |   This is a `typing.Protocol`: implementations do not need to inherit from
87 |   this class, they only need to provide a matching `__call__()` signature.
88 |   """
89 | 
90 |   def __call__(self, features: Features) -> Features:
91 | """Applies the preprocessing op to the features."""
92 |
93 |
94 | # Deprecated. Please use `grain.tensorflow.MapTransform`.
95 | class MapTransform(abc.ABC):
96 | """Base class for transformations of single elements.
97 |
98 | This class implements the PreprocessOp interface and also:
99 | - Limits the features to a flat dictionary (instead of an arbitrary nested
100 | dictionary).
101 | - Provides a convenient implementation of `__call__` that can automatically
102 | apply the single transformation to a single example (`FlatFeatures`) or a
103 | `tf.data.Dataset`. The latter is convenient for SeqIO users migrating to
104 | preprocess_spec.py. For multiple transformations we still recommend users
105 | to use the `PreprocessFn` class.
106 |   - Enforces subclasses to be dataclasses.
107 | """
108 |
109 | def __new__(cls, *args, **kwargs):
110 | del args, kwargs
111 | # Check that our subclass instance is a dataclass. We cannot do this with
112 |     # `__init_subclass__` because the dataclasses.dataclass decorator wraps
113 | # the intermediate class which is a subclass of MapTransform but not a
114 | # dataclass.
115 | if not dataclasses.is_dataclass(cls):
116 | raise ValueError(
117 | f"Class {cls} is not a dataclass. We strongly recommend annotating "
118 | "transformations with `@dataclasses.dataclass(frozen=True)`.")
119 | return super().__new__(cls)
120 |
121 | def __call__(self, features: D) -> D:
122 | """Applies the transformation to the features or the dataset."""
123 | logging.warning("clu.preprocess_spec.MapTransform is deprecated. Please "
124 | "switch to grain.tensorflow.MapTransform.")
125 | if isinstance(features, tf.data.Dataset):
126 | return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE)
127 | return self._transform(features)
128 |
129 | @abc.abstractmethod
130 | def _transform(self, features: FlatFeatures) -> FlatFeatures:
131 | """Transforms the features."""
132 |
133 |
134 | # Deprecated. Please use `grain.tensorflow.RandomMapTransform`.
135 | class RandomMapTransform(MapTransform, abc.ABC):
136 | """Base class for random transformations of single elements.
137 |
138 | We require all random transformations to use stateless random operations (e.g.
139 | `tf.random.stateless_uniform()`) and respect the provided random seed. The
140 | user can expect the random seed to be unique for the element.
141 |
142 | If multiple random seeds are required the user can split the seed into N
143 | new seeds:
144 | ```
145 | seeds = tf.unstack(tf.random.experimental.stateless_split(seed, N))
146 | ```
147 | """
148 |
149 | def __call__(self, features: D) -> D:
150 | logging.warning("clu.preprocess_spec.RandomMapTransform is deprecated. "
151 | "Please switch to grain.tensorflow.RandomMapTransform.")
152 | if isinstance(features, tf.data.Dataset):
153 | return features.map(self, num_parallel_calls=tf.data.AUTOTUNE)
154 |
155 | next_seed, seed = tf.unstack(
156 | tf.random.experimental.stateless_split(features.pop(SEED_KEY)))
157 | features = self._transform(features, seed)
158 | features[SEED_KEY] = next_seed
159 | return features
160 |
161 | @abc.abstractmethod
162 | def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures: # pytype: disable=signature-mismatch # overriding-parameter-count-checks
163 | """Transforms the features only using stateless random ops."""
164 |
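# Illustrative subclass (a sketch, not part of the original module): the
# expected pattern is a frozen dataclass whose `_transform()` only uses
# stateless random ops with the provided per-example seed. Note that
# RandomMapTransform itself is deprecated in favor of grain.
@dataclasses.dataclass(frozen=True)
class _RandomBrightnessExample(RandomMapTransform):
  """Adds a random offset to the (assumed) "image" feature."""

  max_delta: float = 0.1

  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:
    delta = tf.random.stateless_uniform(
        [], seed, -self.max_delta, self.max_delta)
    features["image"] = tf.cast(features["image"], tf.float32) + delta
    return features
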
165 |
166 | # Deprecated. Please use `grain.tensorflow.FilterMapTransform`.
167 | class FilterTransform(abc.ABC):
168 |
169 | def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
170 | logging.warning("clu.preprocess_spec.FilterTransform is deprecated. Please "
171 | "switch to grain.tensorflow.FilterTransform.")
172 | return dataset.filter(self._predicate)
173 |
174 | @abc.abstractmethod
175 | def _predicate(self, features: FlatFeatures) -> tf.Tensor:
176 |     """Returns True if the element should be kept."""
177 |
178 |
179 | def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]:
180 | """Helper to return all preprocess ops in a module.
181 |
182 | Modules that define processing ops can simply define:
183 |     all_ops = lambda: preprocess_spec.get_all_ops(__name__)
184 | all_ops() will then return a list with all dataclasses implementing the
185 | PreprocessOp protocol.
186 |
187 | Args:
188 | module_name: Name of the module. The module must already be imported.
189 |
190 | Returns:
191 | List of tuples of process ops. The first tuple element is the class name
192 | converted to snake case (MyAwesomeTransform => my_awesome_transform) and
193 | the second element is the class.
194 | """
195 | def is_op(x):
196 | return (inspect.isclass(x) and dataclasses.is_dataclass(x) and
197 | issubclass(x, PreprocessOp))
198 |
199 | op_name = lambda n: _CAMEL_CASE_RGX.sub("_", n).lower()
200 | members = inspect.getmembers(sys.modules[module_name])
201 | return [(op_name(name), op) for name, op in members if is_op(op)]
202 |
203 |
204 | def _jax_supported_tf_types():
205 | types = [
206 | x for _, x in inspect.getmembers(tf.dtypes)
207 | if isinstance(x, tf.dtypes.DType) and hasattr(jnp, x.name)
208 | ]
209 | # bool is called bool_ in jax and won't be found by the expression above.
210 | return types + [tf.bool]
211 |
212 |
213 | @dataclasses.dataclass
214 | class OnlyJaxTypes:
215 |   """Removes all features whose types are not supported by JAX.
216 |
217 | This filters dense tensors by dtype and removes sparse and ragged tensors.
218 | The latter don't have an equivalent in JAX.
219 |
220 |   Attributes:
221 |     types: List of allowed types. Defaults to all TF types that have an
222 |       equivalent in jax.numpy.
223 | """
224 |
225 | types: List[tf.dtypes.DType] = dataclasses.field(
226 | default_factory=_jax_supported_tf_types)
227 |
228 | def __call__(self, features: Features) -> Features:
229 | features = traverse_util.flatten_dict(features)
230 | for name in list(features):
231 | dtype = features[name].dtype
232 | if dtype not in self.types:
233 | del features[name]
234 | logging.warning(
235 | "Removing feature %r because dtype %s is not supported in JAX.",
236 | name, dtype)
237 | elif isinstance(features[name], tf.SparseTensor):
238 | del features[name]
239 | logging.warning(
240 | "Removing feature %r because sparse tensors are not "
241 | "supported in JAX.", name)
242 | elif isinstance(features[name], tf.RaggedTensor):
243 | del features[name]
244 | logging.warning(
245 |             "Removing feature %r because ragged tensors are not supported in "
246 | "JAX.", name)
247 | features = traverse_util.unflatten_dict(features)
248 | return features # pytype: disable=bad-return-type
249 |
250 |
251 | @dataclasses.dataclass
252 | class PreprocessFn:
253 | """Chain of preprocessing ops combined to a single preprocessing function.
254 |
255 | Attributes:
256 | ops: List of feature transformations. Transformations will be applied in the
257 | given order.
258 | only_jax_types: If True will add the `OnlyJaxTypes` transformation at the
259 | end.
260 | """
261 |
262 | ops: Sequence[PreprocessOp]
263 | only_jax_types: bool
264 |
265 | def __call__(self, features: Features) -> Features:
266 | """Sequentially applies all `self.ops` and returns the result."""
267 | logging.info("Features before preprocessing: %s",
268 | _describe_features(features))
269 | features = features.copy()
270 | for op in self.ops:
271 | features = op(features)
272 | logging.info("Features after op %s:\n%s", op,
273 | _describe_features(features))
274 | logging.info("Features after preprocessing: %s",
275 | _describe_features(features))
276 | if self.only_jax_types:
277 | features = OnlyJaxTypes()(features)
278 | return features
279 |
280 | def __add__(self, other: "PreprocessFn") -> "PreprocessFn":
281 | """Concatenates two `PreprocessingFn`."""
282 | if not isinstance(other, PreprocessFn):
283 | raise ValueError("Can only add other instances of `PreprocessFn`.")
284 | return PreprocessFn(
285 | ops=tuple(self.ops) + tuple(other.ops),
286 | only_jax_types=self.only_jax_types or other.only_jax_types,
287 | )
288 |
289 | def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn":
290 | """Returns a `PreprocessFn` of the sliced ops."""
291 | return PreprocessFn(
292 | ops=self.ops[op_index]
293 | if isinstance(op_index, slice) else [self.ops[op_index]],
294 | only_jax_types=self.only_jax_types,
295 | )
296 |
297 |
298 | def _get_op_class(
299 | expr: List[ast.stmt],
300 | available_ops: Dict[str, Type[PreprocessOp]]) -> Type[PreprocessOp]:
301 | """Gets the process op fn from the given expression."""
302 | if isinstance(expr, ast.Call):
303 | fn_name = expr.func.id
304 | elif isinstance(expr, ast.Name):
305 | fn_name = expr.id
306 | else:
307 | raise ValueError(
308 | f"Could not parse function name from expression: {expr!r}.")
309 | if fn_name in available_ops:
310 | return available_ops[fn_name]
311 | raise ValueError(
312 | f"'{fn_name}' is not available (available ops: {list(available_ops)}).")
313 |
314 |
315 | def _parse_single_preprocess_op(
316 | spec: str, available_ops: Dict[str, Type[PreprocessOp]]) -> PreprocessOp:
317 |   """Parses the spec for a single preprocess op.
318 |
319 | The op can just be the method name or the method name followed by any
320 | arguments (both positional and keyword) to the method.
321 | See the test cases for some valid examples.
322 |
323 | Args:
324 |     spec: String specifying a single preprocessing operation.
325 | available_ops: Available preprocessing ops.
326 |
327 | Returns:
328 |     The PreprocessOp corresponding to the spec.
329 | """
330 | try:
331 | expr = ast.parse(spec, mode="eval").body # pytype: disable=attribute-error
332 | except SyntaxError as e:
333 | raise ValueError(f"{spec!r} is not a valid preprocess op spec.") from e
334 | op_class = _get_op_class(expr, available_ops) # pytype: disable=wrong-arg-types
335 |
336 | # Simple case without arguments.
337 | if isinstance(expr, ast.Name):
338 | return op_class()
339 |
340 | assert isinstance(expr, ast.Call)
341 | args = [ast.literal_eval(arg) for arg in expr.args]
342 | kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords}
343 | if not args:
344 | return op_class(**kwargs)
345 |
346 | # Translate positional arguments into keyword arguments.
347 | available_arg_names = [f.name for f in dataclasses.fields(op_class)]
348 | for i, arg in enumerate(args):
349 | name = available_arg_names[i]
350 | if name in kwargs:
351 | raise ValueError(
352 | f"Argument {name} to {op_class} given both as positional argument "
353 | f"(value: {arg}) and keyword argument (value: {kwargs[name]}).")
354 | kwargs[name] = arg
355 |
356 | return op_class(**kwargs)
357 |
358 |
359 | def parse(spec: str,
360 | available_ops: List[Tuple[str, Type[PreprocessOp]]],
361 | *,
362 | only_jax_types: bool = True) -> PreprocessFn:
363 | """Parses a preprocess spec; a '|' separated list of preprocess ops."""
364 | available_ops = dict(available_ops)
365 | if not spec.strip():
366 | ops = []
367 | else:
368 | ops = [
369 | _parse_single_preprocess_op(s, available_ops) for s in spec.split("|")
370 | ]
371 | return PreprocessFn(ops, only_jax_types=only_jax_types)
372 |
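# End-to-end sketch (illustrative, not part of the original module): ops are
# plain frozen dataclasses implementing PreprocessOp, exposed via
# get_all_ops() and combined with parse(). The "image" feature and the
# Rescale op below are made up for the example.
#
#   @dataclasses.dataclass(frozen=True)
#   class Rescale:
#     factor: float = 1.0
#
#     def __call__(self, features: Features) -> Features:
#       features["image"] = tf.cast(features["image"], tf.float32) * self.factor
#       return features
#
#   all_ops = lambda: get_all_ops(__name__)
#   preprocess_fn = parse("rescale(factor=2.0)", all_ops())
#   batch = preprocess_fn({"image": tf.ones((8, 8, 3), tf.uint8)})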
373 |
374 | def _describe_features(features: Features) -> str:
375 | description = {}
376 | for k, v in features.items():
377 | if isinstance(v, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
378 | description[k] = f"{v.dtype.name}{v.shape}"
379 | elif isinstance(v, dict):
380 | description[k] = _describe_features(v)
381 | else:
382 | description[k] = f"Unsupported type {type(v)} at feature '{k}'."
383 | return str(description)
384 |
--------------------------------------------------------------------------------