├── AUTHORS ├── clu ├── run_pytest.google.sh ├── __init__.py ├── internal │ ├── __init__.py │ ├── utils.py │ └── utils_test.py ├── metric_writers │ ├── tf │ │ ├── __init__.py │ │ ├── summary_writer.py │ │ └── summary_writer_test.py │ ├── summary_writer.py │ ├── multi_writer_test.py │ ├── __init__.py │ ├── multi_writer.py │ ├── torch_tensorboard_writer.py │ ├── torch_tensorboard_writer_test.py │ ├── logging_writer_test.py │ ├── async_writer_test.py │ ├── utils.py │ ├── logging_writer.py │ ├── async_writer.py │ ├── utils_test.py │ └── interface.py ├── data │ ├── __init__.py │ ├── dataset_iterator_test.py │ └── dataset_iterator.py ├── platform │ ├── __init__.py │ ├── local.py │ └── interface.py ├── profiler.py ├── values.py ├── asynclib_test.py ├── asynclib.py ├── parameter_overview_test.py ├── preprocess_spec_test.py ├── periodic_actions_test.py ├── parameter_overview.py ├── deterministic_data_test.py └── preprocess_spec.py ├── README.md ├── .github └── workflows │ ├── python-publish.yml │ └── build.yml ├── CONTRIBUTING.md ├── setup.py ├── CHANGELOG.md └── LICENSE /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Common Loop Utils significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | -------------------------------------------------------------------------------- /clu/run_pytest.google.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | 5 | CLU_DST="${CLU_DST:-/tmp/clu}" 6 | CLU_ENV="${CLU_ENV:-/tmp/clu_env}" 7 | 8 | copybara third_party/py/clu/copy.bara.sky local .. \ 9 | --folder-dir="${CLU_DST}" --ignore-noop 10 | 11 | # Note: we're reusing the environment if it already exists. 12 | mkdir -p "${CLU_ENV}" 13 | cd "${CLU_ENV}" 14 | python3 -m virtualenv . 15 | . bin/activate 16 | 17 | cd "${CLU_DST}" 18 | pip install . .[test] 19 | 20 | pytest 21 | -------------------------------------------------------------------------------- /clu/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /clu/internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /clu/metric_writers/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Package __init__ file.""" 16 | -------------------------------------------------------------------------------- /clu/metric_writers/summary_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MetricWriter for writing to TF summary files.""" 16 | # pylint: disable=unused-import 17 | 18 | from .tf.summary_writer import SummaryWriter 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLU - Common Loop Utils 2 | 3 | This repository contains common functionality for writing ML training loops. The 4 | goal is to make training loops short and readable (by moving common tasks to 5 | small libraries) without removing the flexibility required for research. 6 | 7 | To get started, check out this Colab: 8 | 9 | https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb 10 | 11 | If you're looking for usage examples, see: 12 | 13 | https://github.com/google/flax/tree/main/examples 14 | 15 | You can also find answers to common questions about CLU on the Flax GitHub 16 | discussions page: 17 | 18 | https://github.com/google/flax/discussions 19 | 20 | Note: At this point we are not accepting contributions. Please fork the 21 | repository if you want to extend the libraries for your use case.
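
## Example usage

The snippet below is a minimal sketch of the `clu.metric_writers` API, adapted from the usage notes in `clu/metric_writers/__init__.py`. The log directory and the training-loop internals (`train_step`, `num_train_steps`) are placeholders for your own code.

```python
from clu import metric_writers

num_train_steps = 100  # Placeholder: use your real step count.

def train_step():
  return 0.5  # Placeholder: compute and return the real training loss.

# By default this writes TF summary files under the given directory.
writer = metric_writers.create_default_writer("/tmp/clu_demo")
writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
for step in range(num_train_steps):
  loss = train_step()
  if step % 50 == 0:
    writer.write_scalars(step, {"loss": loss})
writer.flush()  # Make sure all values were written.
```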
22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /clu/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """DatasetIterator is an interface for input pipelines.""" 16 | # pylint: disable=g-multiple-import 17 | # pylint: disable=unused-import 18 | 19 | from clu.data.dataset_iterator import ( 20 | Array, 21 | ArraySpec, 22 | DatasetIterator, 23 | Element, 24 | ElementSpec, 25 | TfDatasetIterator, 26 | PeekableDatasetIterator, 27 | PyTree, 28 | ) 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | At this time we are focused on supporting research done by Google Research and 4 | are not accepting patches. 5 | 6 | You are however free to start a fork of the project for your purposes as 7 | permitted by the license. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement (CLA). You (or your employer) retain the copyright to your 13 | contribution; this simply gives us permission to use and redistribute your 14 | contributions as part of the project. Head over to 15 | <https://cla.developers.google.com/> to see your current agreements on file or 16 | to sign a new one. 17 | 18 | You generally only need to submit a CLA once, so if you've already submitted one 19 | (even if it was for a different project), you probably don't need to do it 20 | again. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows 25 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
26 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint. 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Build 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - 'test_*' 11 | pull_request: 12 | branches: 13 | - main 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: ['3.10', '3.11'] 21 | steps: 22 | - name: Cancel previous 23 | uses: styfle/cancel-workflow-action@0.8.0 24 | with: 25 | access_token: ${{ github.token }} 26 | - uses: actions/checkout@v4 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | pip install . 34 | pip install .[test] 35 | - name: Test with pytest and generate coverage report 36 | run: | 37 | pytest . 38 | - name: Upload coverage to Codecov 39 | uses: codecov/codecov-action@v1 40 | with: 41 | file: ./coverage.xml 42 | -------------------------------------------------------------------------------- /clu/platform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Methods for interacting with the experiment platform. 16 | 17 | Use cases include informing the platform of the experiment status and providing 18 | a platform independent interface for interactions. 19 | """ 20 | 21 | import threading 22 | 23 | 24 | from clu.platform.interface import ArtifactType 25 | from clu.platform.interface import WorkUnit 26 | from clu.platform.local import LocalWorkUnit 27 | 28 | # TODO(b/200953513): Migrate away from logging imports (on module level) 29 | # to logging the actual usage. See b/200953513. 30 | 31 | _work_unit = None 32 | _work_unit_lock = threading.Lock() 33 | 34 | 35 | def work_unit() -> WorkUnit: 36 | """Gets the global work unit for this experiment trial.""" 37 | global _work_unit 38 | if _work_unit is None: 39 | with _work_unit_lock: 40 | _work_unit = LocalWorkUnit() 41 | return _work_unit 42 | -------------------------------------------------------------------------------- /clu/platform/local.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation for platform functionality when running locally.""" 16 | 17 | from typing import Any 18 | 19 | from absl import logging 20 | from clu.platform import interface 21 | 22 | WorkUnit = interface.WorkUnit 23 | ArtifactType = interface.ArtifactType 24 | 25 | 26 | class LocalWorkUnit(WorkUnit): 27 | """Dummy work unit for running locally.""" 28 | 29 | @property 30 | def experiment_id(self): 31 | """ID of the experiment of the work unit.""" 32 | return -1 33 | 34 | @property 35 | def id(self): 36 | """Unique identifier for the work unit.""" 37 | return -1 38 | 39 | def set_notes(self, msg: str): 40 | """Sets the notes for this work unit.""" 41 | logging.info("Setting work unit notes: %s", msg) 42 | 43 | def set_task_status(self, msg: str): 44 | """Sets the status string for this task.""" 45 | logging.info("Setting task status: %s", msg) 46 | 47 | def create_artifact(self, artifact_type: ArtifactType, artifact: Any, 48 | description: str): 49 | """Creates an artifact entry for the work unit.""" 50 | logging.info("Created artifact %s of type %s and value %s.", description, 51 | artifact_type, artifact) 52 | -------------------------------------------------------------------------------- /clu/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Methods for running triggering a profiler for accelerators. 16 | 17 | Where results are stored depends on the platform (e.g. TensorBoard). 18 | """ 19 | from collections.abc import Callable, Sequence 20 | import threading 21 | from typing import Optional, Protocol 22 | 23 | from absl import logging 24 | 25 | import jax 26 | 27 | 28 | 29 | def start(logdir: str, options=None): 30 | """Starts profiling.""" 31 | if options is not None: 32 | raise NotImplementedError( 33 | "'options' not supported by clu.profiler.start(). 
Please file an issue " 34 | "at https://github.com/google/jax/issues requesting profiler option " 35 | "support if you need this feature.") 36 | if logdir is None: 37 | raise ValueError("Must specify logdir where profile should be written!") 38 | jax.profiler.start_trace(logdir) 39 | 40 | 41 | def stop() -> Optional[str]: 42 | """Stops profiling.""" 43 | jax.profiler.stop_trace() 44 | 45 | 46 | CollectCallback = Callable[[Optional[str]], None] 47 | 48 | 49 | def collect(logdir: str, 50 | callback: CollectCallback, 51 | hosts: Optional[Sequence[str]] = None, 52 | duration_ms: int = 3_000): 53 | """Calls start() followed by stop() after specified duration.""" 54 | del hosts # not used. 55 | start(logdir) 56 | 57 | def timer_cb(): 58 | stop() 59 | callback(None) 60 | 61 | threading.Timer(duration_ms / 1e3, timer_cb).start() 62 | 63 | 64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """setup.py for Common Loop Utils. 16 | 17 | Install for development: 18 | 19 | pip install -e . .[test] 20 | """ 21 | 22 | import sys 23 | 24 | from setuptools import find_packages 25 | from setuptools import setup 26 | 27 | 28 | if sys.version_info < (3, 10): 29 | sys.exit("Python < 3.10 not supported anymore!") 30 | 31 | tests_require = [ 32 | "pytest", 33 | "tensorflow", 34 | "tensorflow_datasets", 35 | "torch>=2.0.0", 36 | ] 37 | 38 | setup( 39 | name="clu", 40 | version="0.0.12", 41 | description="Set of libraries for ML training loops in JAX.", 42 | author="Common Loop Utils Authors", 43 | author_email="no-reply@google.com", 44 | long_description=open("README.md").read(), 45 | long_description_content_type="text/markdown", 46 | url="http://github.com/google/CommonLoopUtils", 47 | license="Apache 2.0", 48 | packages=find_packages(), 49 | include_package_data=True, 50 | install_requires=[ 51 | "absl-py", 52 | "etils[epath,epy]", 53 | "flax", 54 | "jax", 55 | "jaxlib", 56 | "ml_collections", 57 | "numpy", 58 | "packaging", 59 | "typing_extensions", 60 | "wrapt", 61 | ], 62 | tests_require=tests_require, 63 | extras_require=dict(test=tests_require), 64 | classifiers=[ 65 | "Development Status :: 4 - Beta", 66 | "Intended Audience :: Developers", 67 | "Intended Audience :: Science/Research", 68 | "License :: OSI Approved :: Apache Software License", 69 | "Programming Language :: Python", 70 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 71 | ], 72 | keywords="JAX machine learning", 73 | ) 74 | -------------------------------------------------------------------------------- /clu/platform/interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Interface for work units.""" 16 | 17 | import abc 18 | import enum 19 | from typing import Any 20 | 21 | 22 | class ArtifactType(enum.Enum): 23 | # A URL for dashboards, etc. 24 | URL = 1 25 | # File path. 26 | FILE = 2 27 | # Directory path. 28 | DIRECTORY = 3 29 | 30 | 31 | class WorkUnit(abc.ABC): 32 | """A work unit represents a single trial in an experiment. 33 | 34 | Experiments will usually have multiple work units with different 35 | hyperparameters. Each work unit can have multiple jobs (training, 36 | evaluation, etc.). And jobs can have multiple tasks when the training 37 | is distributed across multiple machines. 38 | """ 39 | 40 | @property 41 | @abc.abstractmethod 42 | def experiment_id(self): 43 | """ID of the experiment of the work unit.""" 44 | 45 | @property 46 | @abc.abstractmethod 47 | def id(self): 48 | """Unique identifier for the work unit.""" 49 | 50 | @property 51 | def name(self): 52 | """Returns the name of the work unit as XID/WID. 53 | 54 | XID is the ID of the experiment and WID is the number of the work unit 55 | within the experiment. 56 | 57 | Returns: 58 | The work unit name. e.g. 12345/1. 59 | """ 60 | return f"{self.experiment_id}/{self.id}" 61 | 62 | @abc.abstractmethod 63 | def set_notes(self, msg: str): 64 | """Sets the notes for this work unit. These are displayed in the UI.""" 65 | 66 | @abc.abstractmethod 67 | def set_task_status(self, msg: str): 68 | """Sets the status string for this task.""" 69 | 70 | @abc.abstractmethod 71 | def create_artifact(self, artifact_type: ArtifactType, artifact: Any, 72 | description: str): 73 | """Creates an artifact entry for the work unit.""" 74 | -------------------------------------------------------------------------------- /clu/metric_writers/multi_writer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for MultiWriter.""" 16 | 17 | from unittest import mock 18 | 19 | from clu.metric_writers import interface 20 | from clu.metric_writers import multi_writer 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | 25 | class MultiWriterTest(tf.test.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.writers = [ 30 | mock.create_autospec(interface.MetricWriter), 31 | mock.create_autospec(interface.MetricWriter) 32 | ] 33 | self.writer = multi_writer.MultiWriter(self.writers) 34 | 35 | def test_write_scalars(self): 36 | self.writer.write_scalars(0, {"a": 3, "b": 0.15}) 37 | self.writer.write_scalars(2, {"a": 5, "b": 0.007}) 38 | self.writer.flush() 39 | for w in self.writers: 40 | w.write_scalars.assert_has_calls([ 41 | mock.call(step=0, scalars={ 42 | "a": 3, 43 | "b": 0.15 44 | }), 45 | mock.call(step=2, scalars={ 46 | "a": 5, 47 | "b": 0.007 48 | }) 49 | ]) 50 | w.flush.assert_called() 51 | 52 | def test_write_pointcloud(self): 53 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32) 54 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32) 55 | config = { 56 | "material": "PointCloudMaterial", 57 | "size": 0.09, 58 | } 59 | self.writer.write_pointcloud( 60 | step=0, 61 | point_clouds={"pcd": point_clouds}, 62 | point_colors={"pcd": point_colors}, 63 | configs={"config": config}, 64 | ) 65 | self.writer.flush() 66 | for w in self.writers: 67 | w.write_pointcloud.assert_called_with( 68 | step=0, 69 | point_clouds={"pcd": point_clouds}, 70 | point_colors={"pcd": point_colors}, 71 | configs={"config": config}, 72 | ) 73 | w.flush.assert_called() 74 | 75 | 76 | if __name__ == "__main__": 77 | tf.test.main() 78 | -------------------------------------------------------------------------------- /clu/metric_writers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Metric writers write ML model outputs during model training and evaluation. 16 | 17 | This module introduces the MetricWriter interface. MetricWriters allow users 18 | to write out metrics about ML models during training and evaluation (e.g. loss, 19 | accuracy). 20 | There is a MetricWriter implementation for each back end (e.g. TensorFlow 21 | summaries) and classes that work on top of other MetricWriters to 22 | write to multiple writers at once or write asynchronously. 23 | 24 | Note: The current interface might not contain write() methods for all possible 25 | data types. We are open to extending the interface to other data types 26 | (e.g. audio). 27 | 28 | Usage: 29 | writer = MyMetricWriterImplementation() 30 | # Before training. 31 | writer.write_hparams({"learning_rate": 0.001, "batch_size": 64}) 32 | # Start training loop.
33 | for step in range(num_train_steps): 34 | loss = train_step() 35 | if step % 50 == 0: 36 | writer.write_scalars(step, {"loss": loss}) 37 | accuracy = evaluate() 38 | writer.write_scalars(step, {"accuracy": accuracy}) 39 | # Make sure all values were written. 40 | writer.flush() # or use metric_writers.ensure_flushes() context. 41 | """ 42 | 43 | # pylint: disable=unused-import 44 | # pylint: disable=g-importing-member 45 | 46 | 47 | from clu.metric_writers.async_writer import AsyncMultiWriter 48 | from clu.metric_writers.async_writer import AsyncWriter 49 | from clu.metric_writers.async_writer import ensure_flushes 50 | from clu.metric_writers.interface import MetricWriter 51 | from clu.metric_writers.logging_writer import LoggingWriter 52 | from clu.metric_writers.multi_writer import MultiWriter 53 | from clu.metric_writers.summary_writer import SummaryWriter 54 | from clu.metric_writers.utils import create_default_writer 55 | from clu.metric_writers.utils import write_values 56 | 57 | # TODO(b/200953513): Migrate away from logging imports (on module level) 58 | # to logging the actual usage. See b/200953513. 59 | 60 | -------------------------------------------------------------------------------- /clu/metric_writers/multi_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """MetricWriter that writes to multiple MetricWriters.""" 16 | 17 | from collections.abc import Mapping, Sequence 18 | from typing import Any, Optional 19 | 20 | from clu.metric_writers import interface 21 | 22 | Array = interface.Array 23 | Scalar = interface.Scalar 24 | 25 | 26 | class MultiWriter(interface.MetricWriter): 27 | """MetricWriter that writes to multiple writers at once.""" 28 | 29 | def __init__(self, writers: Sequence[interface.MetricWriter]): 30 | self._writers = tuple(writers) 31 | 32 | def write_summaries( 33 | self, step: int, 34 | values: Mapping[str, Array], 35 | metadata: Optional[Mapping[str, Any]] = None): 36 | for w in self._writers: 37 | w.write_summaries(step, values, metadata) 38 | 39 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 40 | for w in self._writers: 41 | w.write_scalars(step, scalars) 42 | 43 | def write_images(self, step: int, images: Mapping[str, Array]): 44 | for w in self._writers: 45 | w.write_images(step, images) 46 | 47 | def write_videos(self, step: int, videos: Mapping[str, Array]): 48 | for w in self._writers: 49 | w.write_videos(step, videos) 50 | 51 | def write_audios( 52 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 53 | for w in self._writers: 54 | w.write_audios(step, audios, sample_rate=sample_rate) 55 | 56 | def write_texts(self, step: int, texts: Mapping[str, str]): 57 | for w in self._writers: 58 | w.write_texts(step, texts) 59 | 60 | def write_histograms(self, 61 | step: int, 62 | arrays: Mapping[str, Array], 63 | num_buckets: Optional[Mapping[str, int]] = None): 64 | for w in self._writers: 65 | w.write_histograms(step, arrays, num_buckets) 66 | 67 | def write_pointcloud( 68 | self, 69 | step: int, 70 | point_clouds: Mapping[str, Array], 71 | *, 72 | point_colors: Mapping[str, Array] | None = None, 73 | configs: Mapping[str, str | float | bool | None] | None = None, 74 | ): 75 | for w in self._writers: 76 | w.write_pointcloud( 77 | step, point_clouds, point_colors=point_colors, configs=configs 78 | ) 79 | 80 | def write_hparams(self, hparams: Mapping[str, Any]): 81 | for w in self._writers: 82 | w.write_hparams(hparams) 83 | 84 | def flush(self): 85 | for w in self._writers: 86 | w.flush() 87 | 88 | def close(self): 89 | for w in self._writers: 90 | w.close() 91 | -------------------------------------------------------------------------------- /clu/values.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines available types for use by Metrics when written. 16 | 17 | A Metric should return one of the following types when compute() is called. 
18 | """ 19 | 20 | import dataclasses 21 | from typing import Any, Union, Protocol, runtime_checkable 22 | 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | ArrayType = Union[np.ndarray, jnp.ndarray] 27 | ScalarType = Union[int, float, np.number, np.ndarray, jnp.ndarray] 28 | 29 | 30 | @runtime_checkable 31 | class Value(Protocol): 32 | """Class defining available metric computation return values. 33 | 34 | Types mirror those available in MetricWriter. See 35 | clu/metric_writers/interface.py 36 | """ 37 | value: Any 38 | 39 | 40 | @dataclasses.dataclass 41 | class Summary(Value): 42 | value: ArrayType 43 | metadata: Any 44 | 45 | 46 | @dataclasses.dataclass 47 | class Scalar(Value): 48 | value: ScalarType 49 | 50 | 51 | @dataclasses.dataclass 52 | class Image(Value): 53 | """Image type. 54 | 55 | Mapping from image key to images. Images should have the shape [N, H, W, C] or 56 | [H, W, C], where H is the height, W is the width and C the 57 | number of channels (1 or 3). N is the number of images that will be 58 | written. Image dimensions can differ between different image keys but 59 | not between different steps for the same image key. 60 | """ 61 | value: ArrayType 62 | 63 | 64 | @dataclasses.dataclass 65 | class Audio(Value): 66 | """Audio type. 67 | 68 | Mapping from audio key to audios. Audios should have the shape [N, T, C], 69 | where T is the time length and C the number of channels (1 - 70 | mono, 2 - stereo, >= 3 - surround; not all writers support any number of 71 | channels). N is the number of audios that will be written. Audio 72 | dimensions can differ between different audio keys but not between 73 | different steps for the same audio key. Values should be floating-point 74 | values in [-1, +1]. 75 | """ 76 | value: ArrayType 77 | sample_rate: int 78 | 79 | 80 | @dataclasses.dataclass 81 | class Text(Value): 82 | value: str 83 | 84 | 85 | @dataclasses.dataclass 86 | class Histogram(Value): 87 | # value must be an array of counts (integers) 88 | value: ArrayType 89 | num_buckets: int 90 | 91 | 92 | @dataclasses.dataclass 93 | class HyperParam(Value): 94 | """The name of the hyperparameter should be handled outside this class. 95 | 96 | Value should correspond to a single hyperparameter, while a Mapping[str, 97 | HyperParam] (name to HyperParam) is maintained independently. 98 | """ 99 | value: Any 100 | -------------------------------------------------------------------------------- /clu/internal/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Small utilities by CLU libraries.""" 16 | 17 | import contextlib 18 | import sys 19 | import time 20 | from typing import Any, List, Mapping, Tuple, Union 21 | 22 | from absl import logging 23 | 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import wrapt 27 | 28 | 29 | @contextlib.contextmanager 30 | def log_activity(activity_name: str): 31 | """Logs `activity_name` and timing information (or exception).""" 32 | t0 = time.time() 33 | logging.info("%s ...", activity_name) 34 | try: 35 | yield 36 | finally: 37 | dt = time.time() - t0 38 | exc, *_ = sys.exc_info() 39 | if exc is not None: 40 | logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt, 41 | exc.__name__) 42 | else: 43 | logging.info("%s finished after %.2fs.", activity_name, dt) 44 | 45 | 46 | 47 | def logged_with(activity_name: str): 48 | """Returns a decorator wrapping a function with `log_activity()`.""" 49 | @wrapt.decorator 50 | def decorator(wrapped, instance, args, kwargs): 51 | del instance # not used 52 | with log_activity(activity_name): 53 | return wrapped(*args, **kwargs) 54 | return decorator 55 | 56 | 57 | def check_param(value, *, ndim=None, dtype=jnp.float32): 58 | """Raises a `ValueError` if `value` does not match ndim/dtype. 59 | 60 | Args: 61 | value: Value to be tested. 62 | ndim: Expected dimensions. 63 | dtype: Expected dtype. 64 | 65 | Raises: 66 | A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value` 67 | is not an instance of `jnp.ndarray`. 68 | """ 69 | if not isinstance(value, (np.ndarray, jnp.ndarray)): 70 | raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}") 71 | if ndim is not None and value.ndim != ndim: 72 | raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}") 73 | if dtype is not None and value.dtype != dtype: 74 | raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}") 75 | 76 | 77 | def flatten_dict( 78 | d: Mapping[str, Any], prefix: Tuple[str, ...] = () 79 | ) -> List[Tuple[str, Union[int, float, str]]]: 80 | """Returns a sequence of flattened (k, v) pairs for tfsummary.hparams(). 81 | 82 | Args: 83 | d: A dict-like object that has an `.item()` method. 84 | prefix: Prefix to add to keys in `d`. 85 | 86 | Returns: 87 | Sequence of (k, v) pairs where k is the flattened key with individual 88 | subkeys separated by dots. `None` values are replaced by the empty string. 89 | """ 90 | ret = [] 91 | for k, v in d.items(): 92 | # Note `ml_collections.ConfigDict` is not (yet) a `Mapping`. 93 | if isinstance(v, Mapping) or hasattr(v, "items"): 94 | ret += flatten_dict(v, prefix + (k,)) 95 | elif isinstance(v, (list, tuple)): 96 | ret += flatten_dict({str(idx): value for idx, value in enumerate(v)}, 97 | prefix + (k,)) 98 | else: 99 | ret.append((".".join(prefix + (k,)), v if v is not None else "")) 100 | return ret 101 | -------------------------------------------------------------------------------- /clu/metric_writers/torch_tensorboard_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MetricWriter for Pytorch summary files. 16 | 17 | Use this writer for the Pytorch-based code. 18 | 19 | """ 20 | 21 | from collections.abc import Mapping 22 | from typing import Any, Optional 23 | from absl import logging 24 | 25 | from clu.metric_writers import interface 26 | from torch.utils import tensorboard 27 | 28 | Array = interface.Array 29 | Scalar = interface.Scalar 30 | 31 | 32 | class TorchTensorboardWriter(interface.MetricWriter): 33 | """MetricWriter that writes Pytorch summary files.""" 34 | 35 | def __init__(self, logdir: str): 36 | super().__init__() 37 | self._writer = tensorboard.SummaryWriter(log_dir=logdir) 38 | 39 | 40 | def write_summaries( 41 | self, step: int, 42 | values: Mapping[str, Array], 43 | metadata: Optional[Mapping[str, Any]] = None): 44 | logging.log_first_n( 45 | logging.WARNING, 46 | "TorchTensorboardWriter does not support writing raw summaries.", 1) 47 | 48 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 49 | for key, value in scalars.items(): 50 | self._writer.add_scalar(key, value, global_step=step) 51 | 52 | def write_images(self, step: int, images: Mapping[str, Array]): 53 | for key, value in images.items(): 54 | self._writer.add_image(key, value, global_step=step, dataformats="HWC") 55 | 56 | def write_videos(self, step: int, videos: Mapping[str, Array]): 57 | logging.log_first_n( 58 | logging.WARNING, 59 | "TorchTensorBoardWriter does not support writing videos.", 1) 60 | 61 | def write_audios( 62 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 63 | for key, value in audios.items(): 64 | self._writer.add_audio( 65 | key, value, global_step=step, sample_rate=sample_rate) 66 | 67 | def write_texts(self, step: int, texts: Mapping[str, str]): 68 | raise NotImplementedError( 69 | "TorchTensorBoardWriter does not support writing texts." 70 | ) 71 | 72 | def write_histograms(self, 73 | step: int, 74 | arrays: Mapping[str, Array], 75 | num_buckets: Optional[Mapping[str, int]] = None): 76 | for tag, values in arrays.items(): 77 | bins = None if num_buckets is None else num_buckets.get(tag) 78 | self._writer.add_histogram( 79 | tag, values, global_step=step, bins="auto", max_bins=bins) 80 | 81 | def write_pointcloud( 82 | self, 83 | step: int, 84 | point_clouds: Mapping[str, Array], 85 | *, 86 | point_colors: Mapping[str, Array] | None = None, 87 | configs: Mapping[str, str | float | bool | None] | None = None, 88 | ): 89 | logging.log_first_n( 90 | logging.WARNING, 91 | "TorchTensorBoardWriter does not support writing point clouds.", 92 | 1, 93 | ) 94 | 95 | def write_hparams(self, hparams: Mapping[str, Any]): 96 | self._writer.add_hparams(hparams, {}) 97 | 98 | def flush(self): 99 | self._writer.flush() 100 | 101 | def close(self): 102 | self._writer.close() 103 | -------------------------------------------------------------------------------- /clu/metric_writers/torch_tensorboard_writer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for TorchTensorboardWriter.""" 16 | 17 | import collections 18 | import os 19 | from typing import Any, Dict 20 | 21 | from clu.metric_writers import torch_tensorboard_writer 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | def _load_scalars_data(logdir: str): 27 | """Loads scalar summaries from events in a logdir.""" 28 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 29 | data = collections.defaultdict(dict) 30 | for path in paths: 31 | for event in tf.compat.v1.train.summary_iterator(path): 32 | for value in event.summary.value: 33 | data[event.step][value.tag] = value.simple_value 34 | 35 | return data 36 | 37 | 38 | def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]: 39 | """Loads histograms summaries from events in a logdir. 40 | 41 | Args: 42 | logdir: a directory to find logs 43 | 44 | Returns: 45 | A generated histograms in a shape step -> tag -> histo. 46 | """ 47 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 48 | data = {} 49 | for path in paths: 50 | for event in tf.compat.v1.train.summary_iterator(path): 51 | if event.step not in data: 52 | data[event.step] = {} 53 | step_data = {} 54 | for value in event.summary.value: 55 | print(" value:", value) 56 | step_data[value.tag] = value.histo 57 | data[event.step].update(step_data) 58 | 59 | return data 60 | 61 | 62 | class TorchTensorboardWriterTest(tf.test.TestCase): 63 | 64 | def setUp(self): 65 | super().setUp() 66 | self.logdir = self.get_temp_dir() 67 | self.writer = torch_tensorboard_writer.TorchTensorboardWriter(self.logdir) 68 | 69 | def test_write_scalar(self): 70 | self.writer.write_scalars(11, {"a": 0.6, "b": 15}) 71 | self.writer.write_scalars(20, {"a": 0.8, "b": 12}) 72 | self.writer.flush() 73 | data = _load_scalars_data(self.logdir) 74 | self.assertAllClose(data[11], {"a": 0.6, "b": 15}) 75 | self.assertAllClose(data[20], {"a": 0.8, "b": 12}) 76 | 77 | def test_write_histograms(self): 78 | self.writer.write_histograms( 79 | 0, { 80 | "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]), 81 | "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]), 82 | }, num_buckets={"a": 2, "b": 2}) 83 | self.writer.write_histograms( 84 | 2, { 85 | "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]), 86 | "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]), 87 | }, num_buckets={"a": 2, "b": 2}) 88 | self.writer.flush() 89 | data = _load_histograms_data(self.logdir) 90 | self.assertNear(data[0]["a"].min, 0.1, 0.001) 91 | self.assertNear(data[0]["a"].max, 0.7, 0.001) 92 | self.assertNear(data[0]["b"].min, -0.1, 0.001) 93 | self.assertNear(data[0]["b"].max, 0.4, 0.001) 94 | self.assertNear(data[2]["a"].min, -0.1, 0.001) 95 | self.assertNear(data[2]["a"].max, 0.5, 0.001) 96 | self.assertNear(data[2]["b"].min, 0.0, 0.001) 97 | self.assertNear(data[2]["b"].max, 0.7, 0.001) 98 | 99 | 100 | if __name__ == "__main__": 101 | tf.test.main() 102 | 
-------------------------------------------------------------------------------- /clu/internal/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from unittest import mock 16 | 17 | from absl.testing import absltest 18 | from clu.internal import utils 19 | import jax.numpy as jnp 20 | import ml_collections 21 | 22 | 23 | class TestError(BaseException): 24 | __test__ = False 25 | pass 26 | 27 | 28 | class HelpersTest(absltest.TestCase): 29 | 30 | def test_log_activity( 31 | self, 32 | ): 33 | with self.assertLogs() as logs: 34 | with utils.log_activity("test_activity"): 35 | pass 36 | self.assertLen(logs.output, 2) 37 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") 38 | self.assertRegex(logs.output[1], 39 | r"^INFO:absl:test_activity finished after \d+.\d\ds.$") 40 | 41 | def test_log_activity_fails( 42 | self, 43 | ): 44 | with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long 45 | with self.assertLogs() as logs: 46 | with utils.log_activity("test_activity"): 47 | raise TestError() 48 | self.assertLen(logs.output, 2) 49 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") 50 | self.assertRegex(logs.output[1], 51 | r"^ERROR:absl:test_activity FAILED after \d+.\d\ds") 52 | 53 | def test_logged_with(self): 54 | 55 | @utils.logged_with("test_activity") 56 | def test(): 57 | pass 58 | 59 | with self.assertLogs() as logs: 60 | test() 61 | self.assertLen(logs.output, 2) 62 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") 63 | self.assertRegex(logs.output[1], 64 | r"^INFO:absl:test_activity finished after \d+.\d\ds.$") 65 | 66 | def test_logged_with_fails(self): 67 | 68 | @utils.logged_with("test_activity") 69 | def test(): 70 | raise TestError() 71 | 72 | with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long 73 | with self.assertLogs() as logs: 74 | test() 75 | self.assertLen(logs.output, 2) 76 | self.assertEqual(logs.output[0], "INFO:absl:test_activity ...") 77 | self.assertRegex(logs.output[1], 78 | r"^ERROR:absl:test_activity FAILED after \d+.\d\ds") 79 | 80 | def test_check_param(self): 81 | a = jnp.array(0.) 
82 | with self.assertRaisesRegex(ValueError, r"^Expected np.array or jnp.array"): 83 | utils.check_param(None, ndim=1) 84 | with self.assertRaisesRegex(ValueError, r"^Expected ndim"): 85 | utils.check_param(a, ndim=1) 86 | with self.assertRaisesRegex(ValueError, r"^Expected dtype"): 87 | utils.check_param(a, ndim=0, dtype=jnp.int32) 88 | utils.check_param(a, ndim=0) # should work 89 | utils.check_param(a, ndim=0, dtype=jnp.float32) # should also work 90 | 91 | def test_flatten_dict(self): 92 | self.assertEqual( 93 | utils.flatten_dict( 94 | ml_collections.ConfigDict({ 95 | "x": 1, 96 | "y": None, 97 | "z": ml_collections.ConfigDict({ 98 | "a": "bc", 99 | }) 100 | })), [("x", 1), ("y", ""), ("z.a", "bc")]) 101 | 102 | 103 | if __name__ == "__main__": 104 | absltest.main() 105 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.0.1-alpha.1 4 | 5 | Initial PyPi Release 6 | 7 | Current list of modules: 8 | 9 | - `clu.checkpoint` 10 | - `clu.deterministic_training` 11 | - `clu.metric_writers` 12 | - `clu.periodic_actions` 13 | - `clu.platform` 14 | - `clu.profiler` 15 | 16 | ## v0.0.1-alpha.2 17 | 18 | - Adds `metrics` module and some minor changes. 19 | 20 | ## v0.0.1a3 21 | 22 | - Added `metric_writers.TorchTensorboardWriter` 23 | 24 | ## v0.0.2 25 | 26 | - Added preprocess_spec. 27 | - Improvements to periodic_actions. 28 | 29 | ## v0.0.3 30 | 31 | - `metric_writers`: Lets `SummaryWriter` write nested dictionaries. 32 | - `internal`: Adds `async.Pool`. 33 | - `preprocess_spec`: Support nested dictionaries. 34 | - `profile`: Use JAX profiler APIs instead of TF profiler APIs. 35 | 36 | ## v0.0.4 37 | 38 | `deterministic_data` 39 | 40 | - Support non-positive input value for pad_up_to_batches. 41 | - Support padding dataset when data dimension is unknown. 42 | - Support TFDS specs in get_read_instruction_for_host. 43 | - Allow changing drop_remainder for batching. 44 | - Add RemainderOptions in deterministic_data. 45 | 46 | `metric_writers` 47 | 48 | - Support multiple writers in metric_writers.ensure_flushes. 49 | 50 | `metrics` 51 | 52 | - Makes internal.flatten_dict() work with ConfigDicts. 53 | - Forwards mask model output to metrics created via `Metric.from_output()`. 54 | - Forwards mask model output to metrics created via `Metric.from_fun()`. 55 | - Added `Collections.unreplicate()`, `Collections.create()`. 56 | 57 | `periodic_actions` 58 | 59 | - Formats long time strings in '{days}d{hours}h{mins}m' format. 60 | 61 | `preprocess_spec` 62 | 63 | - Make feature description of features in PreprocessFn more compact. 64 | - Better type check in `preprocess_spec.get_all_ops()`. 65 | 66 | Documentation: 67 | 68 | - Added `clu_synopsis.ipynb` Colab 69 | 70 | ## v0.0.5 71 | 72 | - Log error instead of failing when `profiler.start()` raises an exception. 73 | - Makes `periodic_actions.ProgressUpdate` show total number of steps. 74 | - Makes `AsyncWriter` non-blocking wrt JAX async computations. 75 | - Adds `clu_synopsis.ipynb` Colab as initial documentation. 76 | - Restore Checkpoint without providing the state 77 | - Makes `PreprocessFn` addable. 78 | - Allow n-dimensional arrays (and masks) to be passed to Metrics.Average(). 79 | - Support slicing `PreprocessFn`. 80 | 81 | ## v0.0.6 82 | 83 | - Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`. 84 | 85 | This will be the last release supporting Python 3.6. 
86 | 87 | ## v0.0.7 88 | 89 | - Moves `clu.internal.asynclib` to `clu.asynclib`. 90 | - Adds methods for writing raw tensors and audio to `MetricWriter`. 91 | - Adds `clu.values` to annotate arrays with a modality. 92 | - Adds `clu.data.DatasetIterator` - a generic interface between input 93 | pipelines and training loops. 94 | - Fixes various issues with `clu.metrics`. 95 | 96 | This will be the last release supporting Python 3.7. 97 | 98 | ## v0.0.9 99 | 100 | - Fix pytype failures related to teaching pytype about NumPy scalar types. 101 | - Fix a couple of docstring typos. 102 | - Updates README and clu_synposis.ipynb 103 | 104 | Last release before dropping support for Python 3.8 and 3.9 105 | 106 | ## v0.0.10 107 | 108 | - `clu.parameter_overview` now supports JAX global arrays. 109 | - Various small fixes in `clu.metrics` module. 110 | - Removed some tensorflow dependencies. 111 | 112 | ## v0.0.11 113 | 114 | - Removes numpy version pin 115 | - Adds sharding annotations, dtype, total bytes to `parameter_overview` 116 | - Makes `clu.metrics.Std` support same shapes as `clu.metrics.Average` 117 | 118 | ## v0.0.12 119 | 120 | - Switch from `jax.tree_map` (deprecated since JAX 0.4.26) to 121 | `jax.tree_util.tree_map`. 122 | - Improvements to parameter overview. 123 | -------------------------------------------------------------------------------- /clu/asynclib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for clu.asynclib.""" 16 | 17 | from unittest import mock 18 | 19 | from absl.testing import absltest 20 | from clu import asynclib 21 | 22 | 23 | class AsyncWriterTest(absltest.TestCase): 24 | 25 | def test_async_execution(self): 26 | pool = asynclib.Pool() 27 | counter = 0 28 | 29 | @pool 30 | def fn(counter_increment, return_value): 31 | nonlocal counter 32 | counter += counter_increment 33 | return return_value 34 | 35 | future = fn(1, return_value=2) 36 | self.assertEqual(counter, 1) 37 | self.assertEqual(future.result(), 2) 38 | 39 | def test_reraise(self): 40 | pool = asynclib.Pool() 41 | 42 | @pool 43 | def error(): 44 | raise ValueError("test") 45 | 46 | error() 47 | self.assertTrue(pool.has_errors) 48 | with self.assertRaisesRegex(asynclib.AsyncError, "test"): 49 | pool.join() 50 | self.assertFalse(pool.has_errors) 51 | 52 | @pool 53 | def noop(): 54 | ... 
55 | 56 | error() 57 | self.assertTrue(pool.has_errors) 58 | with self.assertRaisesRegex(asynclib.AsyncError, "test"): 59 | noop() 60 | self.assertFalse(pool.has_errors) 61 | 62 | pool.join() 63 | 64 | @mock.patch("concurrent.futures.ThreadPoolExecutor") 65 | def test_queue_length(self, executor_mock): 66 | pool_mock = mock.Mock() 67 | in_flight = [] 68 | 69 | def execute_one(): 70 | in_flight.pop(0)() 71 | 72 | def submit(fn, *args, **kwargs): 73 | in_flight.append(lambda: fn(*args, **kwargs)) 74 | 75 | pool_mock.submit = submit 76 | executor_mock.return_value = pool_mock 77 | 78 | pool = asynclib.Pool() 79 | 80 | @pool 81 | def noop(): 82 | ... 83 | 84 | self.assertEqual(pool.queue_length, 0) 85 | noop() 86 | self.assertEqual(pool.queue_length, 1) 87 | noop() 88 | self.assertEqual(pool.queue_length, 2) 89 | execute_one() 90 | self.assertEqual(pool.queue_length, 1) 91 | execute_one() 92 | self.assertEqual(pool.queue_length, 0) 93 | 94 | @mock.patch("concurrent.futures.ThreadPoolExecutor") 95 | def test_flush(self, executor_mock): 96 | pool_mock = mock.Mock() 97 | pool_mock._in_flight = None 98 | 99 | def execute_one(): 100 | pool_mock._in_flight.pop(0)() 101 | 102 | def submit(fn, *args, **kwargs): 103 | pool_mock._in_flight.append(lambda: fn(*args, **kwargs)) 104 | 105 | def create_pool(max_workers, thread_name_prefix): 106 | del max_workers 107 | del thread_name_prefix 108 | pool_mock._in_flight = [] 109 | return pool_mock 110 | 111 | def shutdown(wait=False): 112 | if wait: 113 | while pool_mock._in_flight: 114 | execute_one() 115 | pool_mock._in_flight = None 116 | 117 | pool_mock.submit = submit 118 | executor_mock.side_effect = create_pool 119 | pool_mock.shutdown.side_effect = shutdown 120 | 121 | pool = asynclib.Pool() 122 | 123 | @pool 124 | def noop(): 125 | ... 126 | 127 | self.assertEqual(pool.queue_length, 0) 128 | noop() 129 | self.assertEqual(pool.queue_length, 1) 130 | noop() 131 | pool.join() 132 | self.assertEqual(pool.queue_length, 0) 133 | noop() 134 | self.assertEqual(pool.queue_length, 1) 135 | 136 | 137 | if __name__ == "__main__": 138 | absltest.main() 139 | -------------------------------------------------------------------------------- /clu/data/dataset_iterator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for dataset_iterator.""" 16 | import itertools 17 | import pathlib 18 | import tempfile 19 | 20 | from absl.testing import parameterized 21 | from clu.data import dataset_iterator 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | INDEX = "_index" 26 | 27 | 28 | class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase): 29 | 30 | def _create_iterator(self, start_index: int, checkpoint: bool = True): 31 | """Create an iterator over some prime numbers with index.""" 32 | primes = tf.constant([2, 3, 5, 7, 11, 13, 17, 19, 23, 29]) 33 | ds = tf.data.Dataset.range(start_index, 10) 34 | ds = ds.map(lambda i: {INDEX: i, "prime": primes[i]}) 35 | # Remove index 1 and 3. 36 | ds = ds.filter(lambda x: tf.logical_and(x["prime"] != 3, x["prime"] != 7)) 37 | ds = ds.batch(2, drop_remainder=True) 38 | return dataset_iterator.TfDatasetIterator(ds, checkpoint=checkpoint) 39 | 40 | def test_tf_iterator(self): 41 | it = self._create_iterator(0) 42 | self.assertEqual( 43 | it.element_spec, { 44 | INDEX: dataset_iterator.ArraySpec(np.int64, (2,)), 45 | "prime": dataset_iterator.ArraySpec(np.int32, (2,)) 46 | }) 47 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 48 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]}) 49 | it.reset() 50 | # Iterator starts from the beginning. 51 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 52 | 53 | def test_tf_iterator_save_and_load(self): 54 | it = self._create_iterator(0) 55 | next(it) 56 | next(it) 57 | next(it) 58 | work_dir = pathlib.Path(tempfile.mkdtemp()) 59 | filename = work_dir / "ckpt" 60 | it.save(filename) 61 | self.assertTrue((work_dir / "ckpt.index").exists()) 62 | 63 | it = self._create_iterator(0) 64 | # Iterator is at the beginning (batch 1). 65 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 66 | it.load(filename) 67 | # Iterator is at the end (batch 4). 68 | self.assertEqual(next(it), {INDEX: [8, 9], "prime": [23, 29]}) 69 | 70 | def test_tf_iterator_save_and_load_no_checkpoint(self): 71 | it = self._create_iterator(0, checkpoint=False) 72 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 73 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]}) 74 | work_dir = pathlib.Path(tempfile.mkdtemp()) 75 | filename = work_dir / "ckpt" 76 | it.save(filename) # Should be a no-op and not create a checkpoint. 77 | self.assertFalse((work_dir / "ckpt.index").exists()) 78 | 79 | it = self._create_iterator(0, checkpoint=False) 80 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 81 | it.restore(filename) # Should be a no-op, iterator just continues. 
82 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]}) 83 | 84 | def test_peekable_dataset_iterator(self): 85 | it = self._create_iterator(0) 86 | it = dataset_iterator.PeekableDatasetIterator(it) 87 | self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]}) 88 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 89 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]}) 90 | 91 | @parameterized.parameters(itertools.product([True, False], [True, False])) 92 | def test_peekable_dataset_iterator_async(self, wait: bool, peek_first: bool): 93 | it = self._create_iterator(0) 94 | it = dataset_iterator.PeekableDatasetIterator(it) 95 | future = it.peek_async() 96 | self.assertIsNone(it._peek) 97 | if wait: 98 | future.result() 99 | self.assertIsNotNone(it._peek) 100 | if peek_first: 101 | self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]}) 102 | self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]}) 103 | self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]}) 104 | 105 | 106 | if __name__ == "__main__": 107 | tf.test.main() 108 | -------------------------------------------------------------------------------- /clu/metric_writers/tf/summary_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MetricWriter for writing to TF summary files. 16 | 17 | Only works in eager mode. Does not work for Pytorch code, please use 18 | TorchTensorboardWriter instead. 
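
Example (an illustrative sketch, assuming a writable log directory):

  writer = SummaryWriter(logdir="/tmp/logs")
  writer.write_scalars(step=0, scalars={"loss": 2.3})
  writer.flush()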
19 | """ 20 | 21 | from collections.abc import Mapping 22 | from typing import Any, Optional 23 | 24 | from absl import logging 25 | 26 | from clu.internal import utils 27 | from clu.metric_writers import interface 28 | from etils import epy 29 | import tensorflow as tf 30 | 31 | with epy.lazy_imports(): 32 | # pylint: disable=g-import-not-at-top 33 | from tensorboard.plugins.hparams import api as hparams_api 34 | from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long 35 | # pylint: enable=g-import-not-at-top 36 | 37 | 38 | Array = interface.Array 39 | Scalar = interface.Scalar 40 | 41 | 42 | class SummaryWriter(interface.MetricWriter): 43 | """MetricWriter that writes TF summary files.""" 44 | 45 | def __init__(self, logdir: str): 46 | super().__init__() 47 | self._summary_writer = tf.summary.create_file_writer(logdir) 48 | 49 | 50 | def write_summaries( 51 | self, 52 | step: int, 53 | values: Mapping[str, Array], 54 | metadata: Optional[Mapping[str, Any]] = None, 55 | ): 56 | with self._summary_writer.as_default(): 57 | for key, value in values.items(): 58 | md = metadata.get(key) if metadata is not None else None 59 | tf.summary.write(key, value, step=step, metadata=md) 60 | 61 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 62 | with self._summary_writer.as_default(): 63 | for key, value in scalars.items(): 64 | tf.summary.scalar(key, value, step=step) 65 | 66 | def write_images(self, step: int, images: Mapping[str, Array]): 67 | with self._summary_writer.as_default(): 68 | for key, value in images.items(): 69 | if len(value.shape) == 3: 70 | value = value[None] 71 | tf.summary.image(key, value, step=step, max_outputs=value.shape[0]) 72 | 73 | def write_videos(self, step: int, videos: Mapping[str, Array]): 74 | logging.log_first_n( 75 | logging.WARNING, 76 | "SummaryWriter does not support writing videos.", 1) 77 | 78 | def write_audios( 79 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 80 | with self._summary_writer.as_default(): 81 | for key, value in audios.items(): 82 | tf.summary.audio(key, value, sample_rate=sample_rate, step=step, 83 | max_outputs=value.shape[0]) 84 | 85 | def write_texts(self, step: int, texts: Mapping[str, str]): 86 | with self._summary_writer.as_default(): 87 | for key, value in texts.items(): 88 | tf.summary.text(key, value, step=step) 89 | 90 | def write_histograms( 91 | self, 92 | step: int, 93 | arrays: Mapping[str, Array], 94 | num_buckets: Optional[Mapping[str, int]] = None, 95 | ): 96 | with self._summary_writer.as_default(): 97 | for key, value in arrays.items(): 98 | buckets = None if num_buckets is None else num_buckets.get(key) 99 | tf.summary.histogram(key, value, step=step, buckets=buckets) 100 | 101 | def write_pointcloud( 102 | self, 103 | step: int, 104 | point_clouds: Mapping[str, Array], 105 | *, 106 | point_colors: Mapping[str, Array] | None = None, 107 | configs: Mapping[str, str | float | bool | None] | None = None, 108 | ): 109 | with self._summary_writer.as_default(): 110 | for key, vertices in point_clouds.items(): 111 | colors = None if point_colors is None else point_colors.get(key) 112 | config = None if configs is None else configs.get(key) 113 | mesh_summary.mesh( 114 | key, 115 | vertices=vertices, 116 | colors=colors, 117 | step=step, 118 | config_dict=config, 119 | ) 120 | 121 | def write_hparams(self, hparams: Mapping[str, Any]): 122 | with self._summary_writer.as_default(): 123 | hparams_api.hparams(dict(utils.flatten_dict(hparams))) 124 | 125 | 
def flush(self): 126 | self._summary_writer.flush() 127 | 128 | def close(self): 129 | self._summary_writer.close() 130 | -------------------------------------------------------------------------------- /clu/metric_writers/logging_writer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for the LoggingWriter.""" 16 | 17 | from clu.metric_writers import logging_writer 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | 22 | class LoggingWriterTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super().setUp() 26 | self.writer = logging_writer.LoggingWriter() 27 | 28 | def test_write_scalars(self): 29 | with self.assertLogs(level="INFO") as logs: 30 | self.writer.write_scalars(0, {"a": 3, "b": 0.15}) 31 | self.writer.write_scalars(2, {"a": 0.0000005, "b": 0.007}) 32 | self.assertEqual( 33 | logs.output, 34 | ["INFO:absl:[0] a=3, b=0.15", "INFO:absl:[2] a=5e-07, b=0.007"]) 35 | 36 | def test_write_images(self): 37 | images = np.zeros((2, 28, 28, 3)) 38 | with self.assertLogs(level="INFO") as logs: 39 | self.writer.write_images(4, {"input_images": images}) 40 | self.assertEqual( 41 | logs.output, 42 | ["INFO:absl:[4] Got images: {'input_images': (2, 28, 28, 3)}."]) 43 | 44 | def test_write_videos(self): 45 | videos = np.zeros((2, 4, 28, 28, 3)) 46 | with self.assertLogs(level="INFO") as logs: 47 | self.writer.write_videos(4, {"input_videos": videos}) 48 | self.assertEqual( 49 | logs.output, 50 | ["INFO:absl:[4] Got videos: {'input_videos': (2, 4, 28, 28, 3)}."]) 51 | 52 | def test_write_texts(self): 53 | with self.assertLogs(level="INFO") as logs: 54 | self.writer.write_texts(4, {"samples": "bla"}) 55 | self.assertEqual( 56 | logs.output, 57 | ["INFO:absl:[4] Got texts: {'samples': 'bla'}."]) 58 | 59 | def test_write_histogram(self): 60 | with self.assertLogs(level="INFO") as logs: 61 | self.writer.write_histograms( 62 | step=4, 63 | arrays={ 64 | "a": np.asarray([-0.1, 0.1, 0.3]), 65 | "b": np.arange(31), 66 | "c": np.asarray([0.1, 0.1, 0.1, 0.1, 0.1]), 67 | }, 68 | num_buckets={ 69 | "a": 2, 70 | "c": 1 71 | }) 72 | # Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by 73 | # default. Last bucket gets 2 values. 
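    # (With the default of 30 buckets over the range [0, 30], each bucket has
    # width 1; every bucket is half-open except the last, which is closed, so
    # the values 29 and 30 both land in it.)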
74 | expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] + 75 | ["[29, 30]: 2"]) 76 | self.assertEqual(logs.output, [ 77 | "INFO:absl:[4] Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}", 78 | f"INFO:absl:[4] Histogram for 'b' = {{{expected_histo_b}}}", 79 | "WARNING:absl:num_buckets was automatically changed from 1 to 2", 80 | "INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}", 81 | ]) 82 | 83 | def test_write_pointcloud(self): 84 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32) 85 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32) 86 | config = { 87 | "material": "PointCloudMaterial", 88 | "size": 0.09, 89 | } 90 | with self.assertLogs(level="INFO") as logs: 91 | self.writer.write_pointcloud( 92 | step=4, 93 | point_clouds={"pcd": point_clouds}, 94 | point_colors={"pcd": point_colors}, 95 | configs={"configs": config}, 96 | ) 97 | self.assertEqual( 98 | logs.output, 99 | [ 100 | "INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)}," 101 | " point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':" 102 | " {'material': 'PointCloudMaterial', 'size': 0.09}}." 103 | ], 104 | ) 105 | 106 | def test_write_hparams(self): 107 | with self.assertLogs(level="INFO") as logs: 108 | self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128}) 109 | self.assertEqual(logs.output, [ 110 | "INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}" 111 | ]) 112 | 113 | def test_collection(self): 114 | writer = logging_writer.LoggingWriter(collection="train") 115 | with self.assertLogs(level="INFO") as logs: 116 | writer.write_scalars(0, {"a": 3, "b": 0.15}) 117 | writer.write_images(4, {"input_images": np.zeros((2, 28, 28, 3))}) 118 | writer.write_texts(4, {"samples": "bla"}) 119 | writer.write_histograms( 120 | step=4, 121 | arrays={ 122 | "a": np.asarray([-0.1, 0.1, 0.3]), 123 | }, 124 | num_buckets={ 125 | "a": 2, 126 | }) 127 | writer.write_hparams({"learning_rate": 0.1}) 128 | 129 | self.assertEqual(logs.output, [ 130 | "INFO:absl:[0] collection=train a=3, b=0.15", 131 | "INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.", 132 | "INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.", 133 | "INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}", 134 | "INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}", 135 | ]) 136 | 137 | 138 | if __name__ == "__main__": 139 | tf.test.main() 140 | -------------------------------------------------------------------------------- /clu/metric_writers/async_writer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for AsyncWriter.""" 16 | 17 | import time 18 | from unittest import mock 19 | 20 | from clu import asynclib 21 | from clu.metric_writers import async_writer 22 | from clu.metric_writers import interface 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | class AsyncWriterTest(tf.test.TestCase): 28 | 29 | def setUp(self): 30 | super().setUp() 31 | self.sync_writer = mock.create_autospec(interface.MetricWriter) 32 | self.writer = async_writer.AsyncWriter(self.sync_writer) 33 | 34 | def test_write_summaries_async(self): 35 | self.writer.write_summaries( 36 | 11, 37 | {"a": np.eye(3, dtype=np.uint8), 38 | "b": np.eye(2, dtype=np.float32)}, 39 | {"a": np.ones((2, 3)).tobytes()}) 40 | self.writer.flush() 41 | self.sync_writer.write_summaries.assert_called_with( 42 | step=11, 43 | values={"a": mock.ANY, "b": mock.ANY}, 44 | metadata={"a": mock.ANY}) 45 | 46 | def test_write_scalars_async(self): 47 | self.writer.write_scalars(0, {"a": 3, "b": 0.15}) 48 | self.writer.write_scalars(2, {"a": 5, "b": 0.007}) 49 | self.writer.flush() 50 | self.sync_writer.write_scalars.assert_has_calls([ 51 | mock.call(step=0, scalars={ 52 | "a": 3, 53 | "b": 0.15 54 | }), 55 | mock.call(step=2, scalars={ 56 | "a": 5, 57 | "b": 0.007 58 | }) 59 | ]) 60 | 61 | def test_write_images(self): 62 | images = np.zeros((2, 28, 28, 3)) 63 | self.writer.write_images(4, {"input_images": images}) 64 | self.writer.flush() 65 | self.sync_writer.write_images.assert_called_with(4, 66 | {"input_images": mock.ANY}) 67 | 68 | def test_write_videos(self): 69 | videos = np.zeros((2, 4, 28, 28, 3)) 70 | self.writer.write_videos(4, {"input_videos": videos}) 71 | self.writer.flush() 72 | self.sync_writer.write_videos.assert_called_with(4, 73 | {"input_videos": mock.ANY}) 74 | 75 | def test_write_pointcloud(self): 76 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32) 77 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32) 78 | config = { 79 | "material": "PointCloudMaterial", 80 | "size": 0.09, 81 | } 82 | self.writer.write_pointcloud( 83 | step=0, 84 | point_clouds={"pcd": point_clouds}, 85 | point_colors={"pcd": point_colors}, 86 | configs={"config": config}, 87 | ) 88 | self.writer.flush() 89 | self.sync_writer.write_pointcloud.assert_called_with( 90 | step=0, 91 | point_clouds={"pcd": mock.ANY}, 92 | point_colors={"pcd": mock.ANY}, 93 | configs={"config": mock.ANY}, 94 | ) 95 | 96 | def test_write_texts(self): 97 | self.writer.write_texts(4, {"samples": "bla"}) 98 | self.writer.flush() 99 | self.sync_writer.write_texts.assert_called_with(4, {"samples": "bla"}) 100 | 101 | def test_ensure_flushes(self): 102 | with async_writer.ensure_flushes(self.writer) as writer: 103 | writer.write_scalars(0, {"a": 3, "b": 0.15}) 104 | writer.write_scalars(2, {"a": 5, "b": 0.007}) 105 | self.sync_writer.write_scalars.assert_has_calls([ 106 | mock.call(step=0, scalars={ 107 | "a": 3, 108 | "b": 0.15 109 | }), 110 | mock.call(step=2, scalars={ 111 | "a": 5, 112 | "b": 0.007 113 | }) 114 | ]) 115 | self.sync_writer.flush.assert_called_once() 116 | 117 | def test_ensure_flushes_with_multiple_writers(self): 118 | sync_writer1 = mock.create_autospec(interface.MetricWriter) 119 | writer1 = async_writer.AsyncWriter(sync_writer1) 120 | sync_writer2 = mock.create_autospec(interface.MetricWriter) 121 | writer2 = async_writer.AsyncWriter(sync_writer2) 122 | 123 | with async_writer.ensure_flushes(writer1, writer2): 124 | writer1.write_scalars(0, {"a": 3, "b": 0.15}) 125 | 
writer2.write_scalars(2, {"a": 5, "b": 0.007}) 126 | 127 | sync_writer1.write_scalars.assert_has_calls( 128 | [mock.call(step=0, scalars={ 129 | "a": 3, 130 | "b": 0.15 131 | })]) 132 | 133 | sync_writer2.write_scalars.assert_has_calls( 134 | [mock.call(step=2, scalars={ 135 | "a": 5, 136 | "b": 0.007 137 | })]) 138 | 139 | sync_writer1.flush.assert_called_once() 140 | sync_writer2.flush.assert_called_once() 141 | 142 | def test_flush_before_close(self): 143 | self.writer.close() 144 | self.sync_writer.flush.assert_called() 145 | self.sync_writer.close.assert_called() 146 | 147 | def test_reraises_exception(self): 148 | self.sync_writer.write_scalars.side_effect = ValueError("foo") 149 | self.writer.write_scalars(0, {"a": 3, "b": 0.15}) 150 | time.sleep(0.1) 151 | with self.assertRaisesRegex(asynclib.AsyncError, "Consider re-running"): 152 | self.writer.write_scalars(2, {"a": 5, "b": 0.007}) 153 | 154 | 155 | if __name__ == "__main__": 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /clu/asynclib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for async function calls.""" 16 | 17 | import collections 18 | import concurrent.futures 19 | import functools 20 | import sys 21 | import threading 22 | from typing import Callable, List, Optional 23 | 24 | from absl import logging 25 | 26 | 27 | class AsyncError(Exception): 28 | """An exception that wraps another exception that ocurred asynchronously.""" 29 | 30 | 31 | class Pool: 32 | """Pool for wrapping functions to be executed asynchronously. 33 | 34 | Synopsis: 35 | 36 | from clu.internal import asynclib 37 | 38 | pool = asynclib.Pool() 39 | @pool 40 | def fn(): 41 | time.sleep(1) 42 | 43 | future = fn() 44 | print(future.result()) 45 | fn() # This could re-raise an exception from the first execution. 46 | print(len(pool)) # Would print "1" because there is one function in flight. 47 | pool.flush() # This could re-raise an exception from the second execution. 48 | """ 49 | 50 | def __init__(self, thread_name_prefix: str = "", 51 | max_workers: Optional[int] = None): 52 | """Creates a new pool that decorates functions for async execution. 53 | 54 | Args: 55 | thread_name_prefix: See documentation of `ThreadPoolExecutor`. 56 | max_workers: See documentation of `ThreadPoolExecutor`. The default `None` 57 | optimizes for parallelizability using the number of CPU cores. If you 58 | specify `max_workers=1` you the async calls are executed in the same 59 | order they have been scheduled. 
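
    For example (sketch):

      pool = Pool(max_workers=1)  # wrapped calls run strictly in submission order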
60 | """ 61 | self._pool = concurrent.futures.ThreadPoolExecutor( 62 | max_workers=max_workers, thread_name_prefix=thread_name_prefix) 63 | self._max_workers = max_workers 64 | self._thread_name_prefix = thread_name_prefix 65 | self._errors = collections.deque() 66 | self._errors_mutex = threading.Lock() 67 | self._queue_length = 0 68 | 69 | def _reraise(self) -> None: 70 | if self._errors: 71 | with self._errors_mutex: 72 | exc_info = self._errors.popleft() 73 | exc = exc_info[1].with_traceback(exc_info[2]) 74 | raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc 75 | 76 | def close(self) -> None: 77 | """Closes this pool & raise a pending exception (if needed).""" 78 | self._pool.shutdown(wait=True) 79 | self._reraise() 80 | 81 | def join(self) -> None: 82 | """Blocks until all functions are processed. 83 | 84 | The pool can be used to schedule more functions after calling this function, 85 | but there might be more exceptions 86 | 87 | Side-effect: 88 | If any of the functions raised an exception, then the first of these 89 | exceptions is reraised. 90 | """ 91 | self._pool.shutdown(wait=True) 92 | self._pool = concurrent.futures.ThreadPoolExecutor( 93 | max_workers=self._max_workers, 94 | thread_name_prefix=self._thread_name_prefix) 95 | self._reraise() 96 | 97 | @property 98 | def queue_length(self) -> int: 99 | """Returns the number of functions that have not returned yet.""" 100 | return self._queue_length 101 | 102 | @property 103 | def has_errors(self) -> bool: 104 | """Returns True if there are any pending errors.""" 105 | return bool(self._errors) 106 | 107 | def clear_errors(self) -> List[Exception]: 108 | """Clears all pending errors and returns them as a (possibly empty) list.""" 109 | with self._errors_mutex: 110 | errors, self._errors = self._errors, collections.deque() 111 | return list(errors) 112 | 113 | def __call__(self, fn: Callable): # pylint: disable=g-bare-generic 114 | """Returns an async version of fn. 115 | 116 | The function will be executed by this class's ThreadPoolExecutor. Any errors 117 | will be stored and re-raised next time any function is called that is 118 | executed through this pool. 119 | 120 | Note that even if there was a previous error, the function is still 121 | scheduled upon re-execution of the wrapper returned by this function. 122 | 123 | Args: 124 | fn: Function to be wrapped. 125 | 126 | Returns: 127 | An async version of `fn`. The return value of that async version will be 128 | a future (unless an exception was re-raised). 129 | """ 130 | 131 | def inner(*args, **kwargs): 132 | 133 | def trap_errors(*args, **kwargs): 134 | try: 135 | return fn(*args, **kwargs) 136 | except Exception as e: 137 | with self._errors_mutex: 138 | self._errors.append(sys.exc_info()) 139 | logging.exception("Error in producer thread for %s", 140 | self._thread_name_prefix) 141 | raise e 142 | finally: 143 | self._queue_length -= 1 144 | 145 | self._queue_length += 1 146 | if not self.has_errors: 147 | return self._pool.submit(trap_errors, *args, **kwargs) 148 | self._pool.submit(trap_errors, *args, **kwargs) 149 | self._reraise() 150 | 151 | if isinstance(fn.__name__, str): 152 | # Regular function. 153 | return functools.wraps(fn)(inner) 154 | # Mock or another weird function that fails with functools.wraps(). 
155 | return inner 156 | -------------------------------------------------------------------------------- /clu/metric_writers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines a generic write interface. 16 | 17 | The write helper accepts a MetricWriter object and a Mapping[str, 18 | clu.metrics.Metric], and automatically writes to the appropriate typed write 19 | method of the writer depending on the type of the metric. 20 | """ 21 | 22 | # pylint: disable=g-importing-member 23 | 24 | import collections 25 | import getpass 26 | import os 27 | import re 28 | from typing import Any, List, Mapping, Optional, Tuple, Union 29 | 30 | from absl import flags 31 | from absl import logging 32 | from clu import values 33 | from clu.metric_writers.async_writer import AsyncMultiWriter 34 | from clu.metric_writers.interface import MetricWriter 35 | from clu.metric_writers.logging_writer import LoggingWriter 36 | from clu.metric_writers.multi_writer import MultiWriter 37 | from clu.metric_writers.summary_writer import SummaryWriter 38 | from etils import epath 39 | import jax.numpy as jnp 40 | import numpy as np 41 | 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def _is_scalar(value: Any) -> bool: 47 | if isinstance(value, values.Scalar) or isinstance( 48 | value, (int, float, np.number) 49 | ): 50 | return True 51 | if isinstance(value, (np.ndarray, jnp.ndarray)): 52 | return value.ndim == 0 or value.size <= 1 53 | return False 54 | 55 | 56 | def write_values( 57 | writer: MetricWriter, 58 | step: int, 59 | metrics: Mapping[ 60 | str, Union[values.Value, values.ArrayType, values.ScalarType] 61 | ], 62 | ): 63 | """Writes all provided metrics. 64 | 65 | Allows providing a mapping of name to Value object, where each Value 66 | specifies a type. The appropriate write method can then be called depending 67 | on the type. 68 | 69 | Args: 70 | writer: MetricWriter object 71 | step: Step at which the arrays were generated. 72 | metrics: Mapping from name to clu.values.Value object. 
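
  Example (illustrative sketch; `LoggingWriter`, `values` and `np` refer to the
  modules imported above):

    writer = LoggingWriter()
    write_values(writer, step=10, metrics={
        "loss": 0.25,                                    # scalar -> write_scalars
        "sample": values.Text("generated text"),         # -> write_texts
        "weights": values.Histogram(np.zeros(100), 30),  # -> write_histograms
    })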
73 | """ 74 | writes = collections.defaultdict(dict) 75 | histogram_num_buckets = collections.defaultdict(int) 76 | for k, v in metrics.items(): 77 | if isinstance(v, values.Summary): 78 | writes[ 79 | (writer.write_summaries, frozenset({"metadata": v.metadata}.items())) 80 | ][k] = v.value 81 | elif _is_scalar(v): 82 | if isinstance(v, values.Scalar): 83 | writes[(writer.write_scalars, frozenset())][k] = v.value 84 | else: 85 | writes[(writer.write_scalars, frozenset())][k] = v 86 | elif isinstance(v, values.Image): 87 | writes[(writer.write_images, frozenset())][k] = v.value 88 | elif isinstance(v, values.Text): 89 | writes[(writer.write_texts, frozenset())][k] = v.value 90 | elif isinstance(v, values.HyperParam): 91 | writes[(writer.write_hparams, frozenset())][k] = v.value 92 | elif isinstance(v, values.Histogram): 93 | writes[(writer.write_histograms, frozenset())][k] = v.value 94 | histogram_num_buckets[k] = v.num_buckets 95 | elif isinstance(v, values.Audio): 96 | writes[( 97 | writer.write_audios, 98 | frozenset({"sample_rate": v.sample_rate}.items()), 99 | )][k] = v.value 100 | else: 101 | raise ValueError("Metric: ", k, " has unsupported value: ", v) 102 | 103 | for (fn, extra_args), vals in writes.items(): 104 | if fn == writer.write_histograms: 105 | # for write_histograms, the num_buckets arg is a Dict indexed by name 106 | writer.write_histograms(step, vals, num_buckets=histogram_num_buckets) 107 | else: 108 | fn(step, vals, **dict(extra_args)) 109 | 110 | 111 | 112 | 113 | def create_default_writer( 114 | logdir: Optional[epath.PathLike] = None, 115 | *, 116 | just_logging: bool = False, 117 | asynchronous: bool = True, 118 | collection: Optional[str] = None, 119 | ) -> MultiWriter: 120 | """Create the default writer for the platform. 121 | 122 | On most platforms this will create a MultiWriter that writes to multiple back 123 | ends (logging, TF summaries etc.). 124 | 125 | Args: 126 | logdir: Logging dir to use for TF summary files. If empty/None will the 127 | returned writer will not write TF summary files. 128 | just_logging: If True only use a LoggingWriter. This is useful in multi-host 129 | setups when only the first host should write metrics and all other hosts 130 | should only write to their own logs. 131 | default (None) will automatically determine if you # GOOGLE-INTERNAL have 132 | asynchronous: If True return an AsyncMultiWriter to not block when writing 133 | metrics. 134 | collection: A string which, if provided, provides an indication that the 135 | provided metrics should all be written to the same collection, or 136 | grouping. 137 | 138 | Returns: 139 | A `MetricWriter` according to the platform and arguments. 140 | """ 141 | if just_logging: 142 | if asynchronous: 143 | return AsyncMultiWriter([LoggingWriter(collection=collection)]) 144 | else: 145 | return MultiWriter([LoggingWriter(collection=collection)]) 146 | writers = [LoggingWriter(collection=collection)] 147 | if logdir is not None: 148 | logdir = epath.Path(logdir) 149 | if collection is not None: 150 | logdir /= collection 151 | writers.append(SummaryWriter(os.fspath(logdir))) 152 | if asynchronous: 153 | return AsyncMultiWriter(writers) 154 | return MultiWriter(writers) 155 | -------------------------------------------------------------------------------- /clu/metric_writers/logging_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MetricWriter that writes all values to INFO log.""" 16 | 17 | from collections.abc import Mapping 18 | from typing import Any, Optional 19 | 20 | from absl import logging 21 | from clu.metric_writers import interface 22 | import numpy as np 23 | 24 | Array = interface.Array 25 | Scalar = interface.Scalar 26 | 27 | 28 | class LoggingWriter(interface.MetricWriter): 29 | """MetricWriter that writes all values to INFO log.""" 30 | 31 | def __init__(self, collection: Optional[str] = None): 32 | if collection: 33 | self._collection_str = f" collection={collection}" 34 | else: 35 | self._collection_str = "" 36 | 37 | def write_summaries( 38 | self, step: int, 39 | values: Mapping[str, Array], 40 | metadata: Optional[Mapping[str, Any]] = None): 41 | logging.info("[%d]%s Got raw tensors: %s.", step, self._collection_str, 42 | {k: v.shape for k, v in values.items()}) 43 | 44 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 45 | values = [ 46 | f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}" 47 | for k, v in sorted(scalars.items()) 48 | ] 49 | logging.info("[%d]%s %s", step, self._collection_str, ", ".join(values)) 50 | 51 | def write_images(self, step: int, images: Mapping[str, Array]): 52 | logging.info("[%d]%s Got images: %s.", step, self._collection_str, 53 | {k: v.shape for k, v in images.items()}) 54 | 55 | def write_videos(self, step: int, videos: Mapping[str, Array]): 56 | logging.info("[%d]%s Got videos: %s.", step, self._collection_str, 57 | {k: v.shape for k, v in videos.items()}) 58 | 59 | def write_audios( 60 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 61 | logging.info("[%d]%s Got audios: %s.", step, self._collection_str, 62 | {k: v.shape for k, v in audios.items()}) 63 | 64 | def write_texts(self, step: int, texts: Mapping[str, str]): 65 | logging.info("[%d]%s Got texts: %s.", step, self._collection_str, texts) 66 | 67 | def write_histograms(self, 68 | step: int, 69 | arrays: Mapping[str, Array], 70 | num_buckets: Optional[Mapping[str, int]] = None): 71 | num_buckets = num_buckets or {} 72 | for key, value in arrays.items(): 73 | histo, bins = _compute_histogram_as_tf( 74 | np.asarray(value), num_buckets=num_buckets.get(key)) 75 | if histo is not None: 76 | logging.info("[%d]%s Histogram for %r = {%s}", step, 77 | self._collection_str, key, 78 | _get_histogram_as_string(histo, bins)) 79 | 80 | def write_pointcloud( 81 | self, 82 | step: int, 83 | point_clouds: Mapping[str, Array], 84 | *, 85 | point_colors: Mapping[str, Any] | None = None, 86 | configs: Mapping[str, str | float | bool | None] | None = None, 87 | ): 88 | logging.info( 89 | "[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.", 90 | step, 91 | self._collection_str, 92 | {k: v.shape for k, v in point_clouds.items()}, 93 | ( 94 | {k: v.shape for k, v in point_colors.items()} 95 | if point_colors is not None 96 | else None 97 | ), 98 | configs, 99 | ) 100 | 101 | def 
write_hparams(self, hparams: Mapping[str, Any]): 102 | logging.info("[Hyperparameters]%s %s", self._collection_str, hparams) 103 | 104 | def flush(self): 105 | logging.flush() 106 | 107 | def close(self): 108 | self.flush() 109 | 110 | 111 | def _compute_histogram_as_tf( 112 | array: np.ndarray, 113 | num_buckets: Optional[int] = None 114 | ) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]: 115 | """Compute the histogram of the input array as TF would do. 116 | 117 | Args: 118 | array: Input data. The histogram is computed over the flattened array. 119 | num_buckets: The number of equal-width bins used to create the histogram. 120 | 121 | Returns: 122 | histo: A numpy array with the values of the histogram. 123 | bins: A numpy array with the bin edges (its length is length(histo)+1). 124 | 125 | If the histogram cannot be built because the array is empty, returns 126 | (None, None). 127 | """ 128 | # See DEFAULT_BUCKET_COUNT in tensorboard/plugins/histogram/summary_v2.py 129 | num_buckets = num_buckets or 30 130 | if num_buckets < 2: 131 | logging.log_first_n(logging.WARNING, 132 | "num_buckets was automatically changed from %d to 2", 1, 133 | num_buckets) 134 | num_buckets = 2 135 | 136 | if array.size == 0: 137 | return None, None 138 | 139 | range_max = np.max(array) 140 | range_min = np.min(array) 141 | if range_max == range_min: 142 | histo = np.asarray([array.size], dtype=np.int64) 143 | bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64) 144 | else: 145 | histo, bins = np.histogram( 146 | array, bins=num_buckets, range=(range_min, range_max)) 147 | bins = np.asarray(bins, dtype=np.float64) 148 | 149 | return histo, bins 150 | 151 | 152 | def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray): 153 | # First items are right-open (i.e. [a, b)). 154 | items = [ 155 | f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}" 156 | for i, count in enumerate(histo[:-1]) 157 | ] 158 | # Last item is right-closed (i.e. [a, b]). 159 | items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}") 160 | return ", ".join(items) 161 | -------------------------------------------------------------------------------- /clu/metric_writers/async_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MetricWriter that writes metrics in a separate thread. 16 | 17 | - The order of the write calls is preserved. 18 | - Users need to all `flush()` or use the `ensure_flushes()` context to make sure 19 | that all metrics have been written. 20 | - Errors while writing in the background thread will be re-raised in the main 21 | thread on the next write_*() call. 
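
Example (an illustrative sketch; `SummaryWriter` is assumed to be imported from
`clu.metric_writers`):

  writer = AsyncWriter(SummaryWriter(logdir="/tmp/logs"))
  with ensure_flushes(writer):
    writer.write_scalars(step=0, scalars={"loss": 2.3})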
22 | """ 23 | 24 | from collections.abc import Mapping, Sequence 25 | import contextlib 26 | from typing import Any, Optional 27 | 28 | from clu import asynclib 29 | 30 | from clu.metric_writers import interface 31 | from clu.metric_writers import multi_writer 32 | import wrapt 33 | 34 | Array = interface.Array 35 | Scalar = interface.Scalar 36 | 37 | 38 | @wrapt.decorator 39 | def _wrap_exceptions(wrapped, instance, args, kwargs): 40 | del instance 41 | try: 42 | return wrapped(*args, **kwargs) 43 | except asynclib.AsyncError as e: 44 | raise asynclib.AsyncError( 45 | "Consider re-running the code without AsyncWriter (e.g. creating a " 46 | "writer using " 47 | "`clu.metric_writers.create_default_writer(asynchronous=False)`)" 48 | ) from e 49 | 50 | 51 | class AsyncWriter(interface.MetricWriter): 52 | """MetricWriter that performs write operations in a separate thread. 53 | 54 | All write operations will be executed in a background thread. If an exceptions 55 | occurs in the background thread it will be raised on the main thread on the 56 | call of one of the write_* methods. 57 | 58 | Use num_workers > 1 at your own risk, if the underlying writer is not 59 | thread-safe or does not expect out-of-order events, this can cause problems. 60 | If num_workers is None then the ThreadPool will use `os.cpu_count()` 61 | processes. 62 | """ 63 | 64 | def __init__(self, 65 | writer: interface.MetricWriter, 66 | *, 67 | num_workers: Optional[int] = 1): 68 | super().__init__() 69 | self._writer = writer 70 | # By default, we have a thread pool with a single worker to ensure that 71 | # calls to the function are run in order (but in a background thread). 72 | self._num_workers = num_workers 73 | self._pool = asynclib.Pool( 74 | thread_name_prefix="AsyncWriter", max_workers=num_workers) 75 | 76 | 77 | @_wrap_exceptions 78 | def write_summaries( 79 | self, step: int, 80 | values: Mapping[str, Array], 81 | metadata: Optional[Mapping[str, Any]] = None): 82 | self._pool(self._writer.write_summaries)( 83 | step=step, values=values, metadata=metadata) 84 | 85 | @_wrap_exceptions 86 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 87 | self._pool(self._writer.write_scalars)(step=step, scalars=scalars) 88 | 89 | @_wrap_exceptions 90 | def write_images(self, step: int, images: Mapping[str, Array]): 91 | self._pool(self._writer.write_images)(step=step, images=images) 92 | 93 | @_wrap_exceptions 94 | def write_videos(self, step: int, videos: Mapping[str, Array]): 95 | self._pool(self._writer.write_videos)(step=step, videos=videos) 96 | 97 | @_wrap_exceptions 98 | def write_audios( 99 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 100 | self._pool(self._writer.write_audios)( 101 | step=step, audios=audios, sample_rate=sample_rate) 102 | 103 | @_wrap_exceptions 104 | def write_texts(self, step: int, texts: Mapping[str, str]): 105 | self._pool(self._writer.write_texts)(step=step, texts=texts) 106 | 107 | @_wrap_exceptions 108 | def write_histograms(self, 109 | step: int, 110 | arrays: Mapping[str, Array], 111 | num_buckets: Optional[Mapping[str, int]] = None): 112 | self._pool(self._writer.write_histograms)( 113 | step=step, arrays=arrays, num_buckets=num_buckets) 114 | 115 | @_wrap_exceptions 116 | def write_pointcloud( 117 | self, 118 | step: int, 119 | point_clouds: Mapping[str, Array], 120 | *, 121 | point_colors: Mapping[str, Array] | None = None, 122 | configs: Mapping[str, str | float | bool | None] | None = None, 123 | ): 124 | 
self._pool(self._writer.write_pointcloud)( 125 | step=step, 126 | point_clouds=point_clouds, 127 | point_colors=point_colors, 128 | configs=configs, 129 | ) 130 | 131 | @_wrap_exceptions 132 | def write_hparams(self, hparams: Mapping[str, Any]): 133 | self._pool(self._writer.write_hparams)(hparams=hparams) 134 | 135 | def flush(self): 136 | try: 137 | self._pool.join() 138 | finally: 139 | self._writer.flush() 140 | 141 | def close(self): 142 | try: 143 | self.flush() 144 | finally: 145 | self._writer.close() 146 | 147 | 148 | class AsyncMultiWriter(multi_writer.MultiWriter): 149 | """AsyncMultiWriter writes to multiple writes in a separate thread.""" 150 | 151 | def __init__(self, 152 | writers: Sequence[interface.MetricWriter], 153 | *, 154 | num_workers: Optional[int] = 1): 155 | super().__init__([AsyncWriter(w, num_workers=num_workers) for w in writers]) 156 | 157 | 158 | @contextlib.contextmanager 159 | def ensure_flushes(*writers: interface.MetricWriter): 160 | """Context manager which ensures that one or more writers are flushed.""" 161 | try: 162 | # The caller should not need to use the yielded value, but we yield 163 | # the first writer to stay backwards compatible for a single writer. 164 | yield writers[0] 165 | finally: 166 | for writer in writers: 167 | writer.flush() 168 | -------------------------------------------------------------------------------- /clu/parameter_overview_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
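# Illustrative usage sketch (not part of the test suite) for the API exercised
# below; `variables` is assumed to be a Flax variable dict like the one the
# tests create:
#
#   from clu import parameter_overview
#
#   print(parameter_overview.get_parameter_overview(variables["params"]))
#   num_params = parameter_overview.count_parameters(variables["params"])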
14 | 15 | """Tests for parameter overviews.""" 16 | 17 | from absl.testing import absltest 18 | from clu import parameter_overview 19 | from flax import linen as nn 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+ 26 | | Name | Shape | Dtype | Size | Mean | Std | 27 | +------+-------+-------+------+------+-----+ 28 | +------+-------+-------+------+------+-----+ 29 | Total: 0 -- 0 bytes""" 30 | 31 | FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+---------+------+ 32 | | Name | Shape | Dtype | Size | 33 | +-------------+--------------+---------+------+ 34 | | conv/bias | (2,) | float32 | 2 | 35 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | 36 | +-------------+--------------+---------+------+ 37 | Total: 56 -- 224 bytes""" 38 | 39 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+ 40 | | Name | Shape | Dtype | Size | Sharding | 41 | +-------------+--------------+---------+------+----------+ 42 | | conv/bias | (2,) | float32 | 2 | () | 43 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | () | 44 | +-------------+--------------+---------+------+----------+ 45 | Total: 56 -- 224 bytes""" 46 | 47 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+ 48 | | Name | Shape | Dtype | Size | Mean | Std | 49 | +-------------+--------------+---------+------+------+-----+ 50 | | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | 51 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | 52 | +-------------+--------------+---------+------+------+-----+ 53 | Total: 56 -- 224 bytes""" 54 | 55 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+ 56 | | Name | Shape | Dtype | Size | Mean | Std | Sharding | 57 | +-------------+--------------+---------+------+------+-----+----------+ 58 | | conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () | 59 | | conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | () | 60 | +-------------+--------------+---------+------+------+-----+----------+ 61 | Total: 56 -- 224 bytes""" 62 | 63 | FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+ 64 | | Name | Shape | Dtype | Size | Mean | Std | 65 | +--------------------+--------------+---------+------+------+-----+ 66 | | params/conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | 67 | | params/conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | 68 | +--------------------+--------------+---------+------+------+-----+ 69 | Total: 56 -- 224 bytes""" 70 | 71 | 72 | class CNN(nn.Module): 73 | 74 | @nn.compact 75 | def __call__(self, x): 76 | return nn.Conv(features=2, kernel_size=(3, 3), name="conv")(x) 77 | 78 | 79 | class JaxParameterOverviewTest(absltest.TestCase): 80 | 81 | def test_count_parameters_empty(self): 82 | self.assertEqual(0, parameter_overview.count_parameters({})) 83 | 84 | def test_count_parameters(self): 85 | rng = jax.random.PRNGKey(42) 86 | # Weights of a 2D convolution with 2 filters. 
87 | variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3))) 88 | # 3 * 3*3 * 2 + 2 (bias) = 56 parameters 89 | self.assertEqual(56, 90 | parameter_overview.count_parameters(variables["params"])) 91 | 92 | def test_get_parameter_overview_empty(self): 93 | self.assertEqual(EMPTY_PARAMETER_OVERVIEW, 94 | parameter_overview.get_parameter_overview({})) 95 | self.assertEqual(EMPTY_PARAMETER_OVERVIEW, 96 | parameter_overview.get_parameter_overview({"a": {}})) 97 | 98 | def test_get_parameter_overview(self): 99 | rng = jax.random.PRNGKey(42) 100 | # Weights of a 2D convolution with 2 filters. 101 | variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3))) 102 | variables = jax.tree_util.tree_map(jnp.ones_like, variables) 103 | self.assertEqual( 104 | FLAX_CONV2D_PARAMETER_OVERVIEW, 105 | parameter_overview.get_parameter_overview( 106 | variables["params"], include_stats=False)) 107 | self.assertEqual( 108 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS, 109 | parameter_overview.get_parameter_overview(variables["params"])) 110 | self.assertEqual( 111 | FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS, 112 | parameter_overview.get_parameter_overview(variables)) 113 | # Add sharding with PartitionSpecs. 114 | mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d") 115 | sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 116 | variables = jax.jit(lambda x: x, out_shardings=sharding)(variables) 117 | self.assertEqual( 118 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING, 119 | parameter_overview.get_parameter_overview( 120 | variables["params"], include_stats="sharding")) 121 | self.assertEqual( 122 | FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING, 123 | parameter_overview.get_parameter_overview( 124 | variables["params"], include_stats="global")) 125 | 126 | def test_get_parameter_overview_shape_dtype_struct(self): 127 | variables_shape_dtype_struct = jax.eval_shape( 128 | lambda: CNN().init(jax.random.PRNGKey(42), jnp.zeros((2, 5, 5, 3)))) 129 | self.assertEqual( 130 | FLAX_CONV2D_PARAMETER_OVERVIEW, 131 | parameter_overview.get_parameter_overview( 132 | variables_shape_dtype_struct["params"], include_stats=False)) 133 | 134 | def test_printing_bool(self): 135 | self.assertEqual( 136 | parameter_overview._default_table_value_formatter(True), "True") 137 | self.assertEqual( 138 | parameter_overview._default_table_value_formatter(False), "False") 139 | 140 | 141 | if __name__ == "__main__": 142 | absltest.main() 143 | -------------------------------------------------------------------------------- /clu/metric_writers/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
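# Illustrative usage sketch (not part of the test suite) for the default-writer
# helper tested below; assumes a writable log directory:
#
#   from clu.metric_writers import utils
#
#   writer = utils.create_default_writer(logdir="/tmp/logs", asynchronous=False)
#   writer.write_scalars(0, {"loss": 0.5})
#   writer.flush()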
14 | 15 | """Tests for interface.""" 16 | # pylint: disable=g-importing-member 17 | 18 | import itertools 19 | from typing import Any 20 | from unittest import mock 21 | 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | from clu import values 25 | from clu.metric_writers import utils 26 | from clu.metric_writers.async_writer import AsyncMultiWriter 27 | from clu.metric_writers.async_writer import AsyncWriter 28 | from clu.metric_writers.interface import MetricWriter 29 | from clu.metric_writers.logging_writer import LoggingWriter 30 | from clu.metric_writers.multi_writer import MultiWriter 31 | from clu.metric_writers.summary_writer import SummaryWriter 32 | import clu.metrics 33 | import flax.struct 34 | import jax.numpy as jnp 35 | import tensorflow as tf 36 | 37 | 38 | @flax.struct.dataclass 39 | class HistogramMetric(clu.metrics.Metric): 40 | value: jnp.ndarray 41 | num_buckets: int 42 | 43 | def compute_value(self): 44 | return values.Histogram(self.value, self.num_buckets) 45 | 46 | 47 | @flax.struct.dataclass 48 | class ImageMetric(clu.metrics.Metric): 49 | value: jnp.ndarray 50 | 51 | def compute_value(self): 52 | return values.Image(self.value) 53 | 54 | 55 | @flax.struct.dataclass 56 | class AudioMetric(clu.metrics.Metric): 57 | value: jnp.ndarray 58 | sample_rate: int 59 | 60 | def compute_value(self): 61 | return values.Audio(self.value, self.sample_rate) 62 | 63 | 64 | @flax.struct.dataclass 65 | class TextMetric(clu.metrics.Metric): 66 | value: str 67 | 68 | def compute_value(self): 69 | return values.Text(self.value) 70 | 71 | 72 | @flax.struct.dataclass 73 | class HyperParamMetric(clu.metrics.Metric): 74 | value: float 75 | 76 | def compute_value(self): 77 | return values.HyperParam(self.value) 78 | 79 | 80 | @flax.struct.dataclass 81 | class SummaryMetric(clu.metrics.Metric): 82 | value: jnp.ndarray 83 | metadata: Any 84 | 85 | def compute_value(self): 86 | return values.Summary(self.value, self.metadata) 87 | 88 | 89 | def _to_summary(metrics): 90 | return {k: v.value for k, v in metrics.items()} 91 | 92 | 93 | def _to_list_of_dicts(d): 94 | return [{k: v} for k, v in d.items()] 95 | 96 | 97 | class ONEOF(object): 98 | """ONEOF(options_list) check value in options_list.""" 99 | 100 | def __init__(self, container): 101 | if not hasattr(container, "__contains__"): 102 | raise TypeError(f"{container!r} is not a container") 103 | if not container: 104 | raise ValueError(f"{container!r} is empty") 105 | self._c = container 106 | 107 | def __eq__(self, o): 108 | return o in self._c 109 | 110 | def __ne__(self, o): 111 | return o not in self._c 112 | 113 | def __repr__(self): 114 | return "".format(",".join(repr(i) for i in self._c)) 115 | 116 | 117 | class MetricWriterTest(tf.test.TestCase, parameterized.TestCase): 118 | 119 | def test_write(self): 120 | writer = mock.Mock(spec_set=MetricWriter) 121 | step = 3 122 | num_buckets = 4 123 | sample_rate = 10 124 | scalar_metrics = { 125 | "loss": clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])), 126 | "accuracy": clu.metrics.LastValue.from_model_output(jnp.asarray([5])), 127 | } 128 | image_metrics = { 129 | "image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])), 130 | } 131 | histogram_metrics = { 132 | "hist": HistogramMetric( 133 | value=jnp.asarray([7, 8]), num_buckets=num_buckets 134 | ), 135 | "hist2": HistogramMetric( 136 | value=jnp.asarray([9, 10]), num_buckets=num_buckets 137 | ), 138 | } 139 | audio_metrics = { 140 | "audio": AudioMetric( 141 | value=jnp.asarray([1, 5]), 
sample_rate=sample_rate 142 | ), 143 | "audio2": AudioMetric( 144 | value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2 145 | ), 146 | } 147 | text_metrics = { 148 | "text": TextMetric(value="hello"), 149 | } 150 | hparam_metrics = { 151 | "lr": HyperParamMetric(value=0.01), 152 | } 153 | summary_metrics = { 154 | "summary": SummaryMetric( 155 | value=jnp.asarray([2, 3, 10]), metadata="some info" 156 | ), 157 | "summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5), 158 | } 159 | metrics = { 160 | **scalar_metrics, 161 | **image_metrics, 162 | **histogram_metrics, 163 | **audio_metrics, 164 | **text_metrics, 165 | **hparam_metrics, 166 | **summary_metrics, 167 | } 168 | metrics = {k: m.compute_value() for k, m in metrics.items()} 169 | utils.write_values(writer, step, metrics) 170 | 171 | writer.write_scalars.assert_called_once_with( 172 | step, {k: m.compute() for k, m in scalar_metrics.items()} 173 | ) 174 | writer.write_images.assert_called_once_with( 175 | step, _to_summary(image_metrics) 176 | ) 177 | writer.write_histograms.assert_called_once_with( 178 | step, 179 | _to_summary(histogram_metrics), 180 | num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()}, 181 | ) 182 | writer.write_audios.assert_called_with( 183 | step, 184 | ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))), 185 | sample_rate=ONEOF([sample_rate, sample_rate + 2]), 186 | ) 187 | writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics)) 188 | writer.write_hparams.assert_called_once_with( 189 | step, _to_summary(hparam_metrics) 190 | ) 191 | writer.write_summaries.assert_called_with( 192 | step, 193 | ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))), 194 | metadata=ONEOF(["some info", 5]), 195 | ) 196 | 197 | 198 | def test_create_default_writer_summary_writer_is_added(self): 199 | writer = utils.create_default_writer( 200 | logdir=self.get_temp_dir(), asynchronous=False 201 | ) 202 | self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers)) 203 | writer = utils.create_default_writer(logdir=None, asynchronous=False) 204 | self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers)) 205 | 206 | 207 | if __name__ == "__main__": 208 | absltest.main() 209 | -------------------------------------------------------------------------------- /clu/preprocess_spec_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
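# Illustrative usage sketch (not part of the test suite) for the spec-parsing
# API exercised below; op classes are frozen dataclasses discovered via
# get_all_ops() and composed from a "|"-separated spec string:
#
#   from clu import preprocess_spec
#
#   ops = preprocess_spec.get_all_ops(__name__)
#   preprocess_fn = preprocess_spec.parse("to_float|rescale(3)", ops)
#   features = preprocess_fn(
#       {"image": tf.constant(3), "segmentation_mask": tf.constant(2)})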
14 | 15 | import dataclasses 16 | 17 | from absl import logging 18 | from absl.testing import parameterized 19 | from clu import preprocess_spec 20 | import tensorflow as tf 21 | 22 | Features = preprocess_spec.Features 23 | SEED_KEY = preprocess_spec.SEED_KEY 24 | 25 | 26 | @dataclasses.dataclass(frozen=True) 27 | class ToFloat: 28 | 29 | def __call__(self, features: Features) -> Features: 30 | return {k: tf.cast(v, tf.float32) / 255.0 for k, v in features.items()} 31 | 32 | 33 | @dataclasses.dataclass(frozen=True) 34 | class Rescale: 35 | 36 | scale: int = 1 37 | 38 | def __call__(self, features: Features) -> Features: 39 | features["image"] *= self.scale 40 | features["segmentation_mask"] *= self.scale 41 | return features 42 | 43 | 44 | @dataclasses.dataclass(frozen=True) 45 | class AddRandomInteger(preprocess_spec.RandomMapTransform): 46 | 47 | def _transform(self, features, seed): 48 | features["x"] = tf.random.stateless_uniform([], seed) 49 | return features 50 | 51 | 52 | all_ops = lambda: preprocess_spec.get_all_ops(__name__) 53 | 54 | 55 | class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase): 56 | """Tests for parsing preprocessing op spec.""" 57 | 58 | def test_no_arguments(self): 59 | op = preprocess_spec._parse_single_preprocess_op("rescale", dict(all_ops())) 60 | logging.info("op: %r", op) 61 | self.assertEqual(str(op), "Rescale(scale=1)") 62 | 63 | def test_positional_argument(self): 64 | op = preprocess_spec._parse_single_preprocess_op("rescale(2)", 65 | dict(all_ops())) 66 | logging.info("op: %r", op) 67 | self.assertEqual(str(op), "Rescale(scale=2)") 68 | 69 | def test_keyword_argument(self): 70 | op = preprocess_spec._parse_single_preprocess_op("rescale(scale=3)", 71 | dict(all_ops())) 72 | logging.info("op: %r", op) 73 | self.assertEqual(str(op), "Rescale(scale=3)") 74 | 75 | def test_invalid_op_name(self): 76 | with self.assertRaisesRegex( 77 | ValueError, r"'does_not_exist' is not available \(available ops: " 78 | r"\['add_random_integer', 'rescale', 'to_float'\]\)."): 79 | preprocess_spec._parse_single_preprocess_op("does_not_exist", 80 | dict(all_ops())) 81 | 82 | def test_invalid_spec(self): 83 | with self.assertRaisesRegex( 84 | ValueError, r"'rescale\)' is not a valid preprocess op spec."): 85 | preprocess_spec._parse_single_preprocess_op("rescale)", dict(all_ops())) 86 | 87 | def test_pos_and_kw_arg(self): 88 | with self.assertRaisesRegex( 89 | ValueError, 90 | r"Rescale'> given both as positional argument \(value: 2\) and keyword " 91 | r"argument \(value: 3\)."): 92 | preprocess_spec._parse_single_preprocess_op("rescale(2, scale=3)", 93 | dict(all_ops())) 94 | 95 | def test_parsing_empty_string(self): 96 | preprocess_fn = preprocess_spec.parse("", all_ops()) 97 | self.assertEqual( 98 | str(preprocess_fn), "PreprocessFn(ops=[], only_jax_types=True)") 99 | 100 | def test_multi_op_spec(self): 101 | preprocess_fn = preprocess_spec.parse("to_float|rescale(3)", all_ops()) 102 | logging.info("preprocess_fn: %r", preprocess_fn) 103 | self.assertEqual(str(preprocess_fn.ops), "[ToFloat(), Rescale(scale=3)]") 104 | 105 | def test_two_tensors(self): 106 | preprocess_fn = preprocess_spec.parse("rescale(scale=7)", all_ops()) 107 | x = {"image": tf.constant(3), "segmentation_mask": tf.constant(2)} 108 | y = preprocess_fn(x) 109 | self.assertEqual(y, { 110 | "image": tf.constant(21), 111 | "segmentation_mask": tf.constant(14), 112 | }) 113 | 114 | def test_only_jax_types(self): 115 | preprocess_fn = preprocess_spec.parse("", all_ops()) 116 | x = { 117 | "image": 
tf.constant(2), 118 | # Strings are not supported. 119 | "label": tf.constant("bla"), 120 | # Sparse tensors are not supported. 121 | "foo": tf.sparse.eye(4), 122 | # Ragged tensors are not supported. 123 | "bar": tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]]), 124 | } 125 | y = preprocess_fn(x) 126 | self.assertEqual(y, {"image": tf.constant(2)}) 127 | 128 | def test_only_jax_types_nested_inputs(self): 129 | preprocess_fn = preprocess_spec.parse("", all_ops()) 130 | x = { 131 | "nested": { 132 | "not_allowed": tf.constant("bla"), 133 | "allowed": tf.constant(2), 134 | } 135 | } 136 | y = preprocess_fn(x) 137 | self.assertEqual(y, {"nested": {"allowed": tf.constant(2)}}) 138 | 139 | def test_not_only_jax_types(self): 140 | preprocess_fn = preprocess_spec.parse("", all_ops(), only_jax_types=False) 141 | x = {"image": tf.constant(2), "label": tf.constant("bla")} 142 | y = preprocess_fn(x) 143 | self.assertEqual(y, x) 144 | 145 | def test_add_preprocess_fn(self): 146 | op1 = ToFloat() 147 | op2 = ToFloat() 148 | op3 = ToFloat() 149 | fn1 = preprocess_spec.PreprocessFn(ops=(op1, op2), only_jax_types=False) 150 | fn2 = preprocess_spec.PreprocessFn(ops=(op3,), only_jax_types=True) 151 | fn12 = fn1 + fn2 152 | # Note: `+` is not supported on Sequence[PreprocessOp]; need to use `list`. 153 | self.assertSequenceEqual(fn12.ops, list(fn1.ops) + list(fn2.ops)) 154 | self.assertTrue(fn12.only_jax_types) 155 | 156 | def test_slice_preprocess_fn(self): 157 | op1 = ToFloat() 158 | op2 = Rescale() 159 | op3 = ToFloat() 160 | fn = preprocess_spec.PreprocessFn(ops=(op1, op2, op3), only_jax_types=True) 161 | self.assertEqual(fn[:-1].ops, (op1, op2)) 162 | self.assertTrue(fn[:-1].only_jax_types) 163 | self.assertEqual(fn[1].ops, [op2]) 164 | self.assertTrue(fn[1].only_jax_types) 165 | 166 | def test_random_map_transform(self): 167 | ds = tf.data.Dataset.from_tensor_slices( 168 | {SEED_KEY: [[1, 2], [3, 4], [1, 2]]}) 169 | ds = ds.map(AddRandomInteger()) 170 | actual = list(ds) 171 | print("actual:", actual) 172 | expect = [ 173 | # Random number was generated and random seed changed. 174 | { 175 | "x": 0.8838011, 176 | SEED_KEY: [1105988140, 1738052849] 177 | }, 178 | { 179 | "x": 0.33396423, 180 | SEED_KEY: [-1860230133, -671226999] 181 | }, 182 | # Same random seed as first element creates same outcome. 183 | { 184 | "x": 0.8838011, 185 | SEED_KEY: [1105988140, 1738052849] 186 | }, 187 | ] 188 | self.assertAllClose(actual, expect) 189 | 190 | 191 | if __name__ == "__main__": 192 | tf.test.main() 193 | -------------------------------------------------------------------------------- /clu/metric_writers/interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for unify reporting model metrics across various logging formats. 
16 | 17 | This library provides a MetricWriter for each logging format (SummaryWriter, 18 | LoggingWriter, etc.) and composing MetricWriters that add support for asynchronous 19 | logging or writing to multiple formats. 20 | """ 21 | 22 | import abc 23 | from collections.abc import Mapping 24 | from typing import Any, Optional, Union 25 | 26 | import jax.numpy as jnp 27 | import numpy as np 28 | 29 | Array = Union[np.ndarray, jnp.ndarray] 30 | Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray] 31 | 32 | 33 | class MetricWriter(abc.ABC): 34 | """MetricWriter interface.""" 35 | 36 | @abc.abstractmethod 37 | def write_summaries( 38 | self, step: int, 39 | values: Mapping[str, Array], 40 | metadata: Optional[Mapping[str, Any]] = None): 41 | """Saves an arbitrary tensor summary. 42 | 43 | Useful when working with custom plugins or constructing a summary directly. 44 | 45 | Args: 46 | step: Step at which the scalar values occurred. 47 | values: Mapping from tensor keys to tensors. 48 | metadata: Optional SummaryMetadata, as a proto or serialized bytes. 49 | Note that markdown formatting is rendered by tensorboard. 50 | """ 51 | 52 | @abc.abstractmethod 53 | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): 54 | """Write scalar values for the step. 55 | 56 | Consecutive calls to this method can provide different sets of scalars. 57 | Repeated writes for the same metric at the same step are not allowed. 58 | 59 | Args: 60 | step: Step at which the scalar values occurred. 61 | scalars: Mapping from metric name to value. 62 | """ 63 | 64 | @abc.abstractmethod 65 | def write_images(self, step: int, images: Mapping[str, Array]): 66 | """Write images for the step. 67 | 68 | Consecutive calls to this method can provide different sets of images. 69 | Repeated writes for the same image key at the same step are not allowed. 70 | 71 | Warning: Not all MetricWriter implementations support writing images! 72 | 73 | Args: 74 | step: Step at which the images occurred. 75 | images: Mapping from image key to images. Images should have the shape [N, 76 | H, W, C] or [H, W, C], where H is the height, W is the width and C the 77 | number of channels (1 or 3). N is the number of images that will be 78 | written. Image dimensions can differ between different image keys but 79 | not between different steps for the same image key. 80 | """ 81 | 82 | @abc.abstractmethod 83 | def write_videos(self, step: int, videos: Mapping[str, Array]): 84 | """Write videos for the step. 85 | 86 | Warning: Logging only. 87 | Not all MetricWriter implementations support writing videos! 88 | 89 | Consecutive calls to this method can provide different sets of videos. 90 | Repeated writes for the same video key at the same step are not allowed. 91 | 92 | 93 | Args: 94 | step: Step at which the videos occurred. 95 | videos: Mapping from video key to videos. Videos should have the shape 96 | [N, T, H, W, C] or [T, H, W, C], where T is time, H is the height, 97 | W is the width and C the number of channels (1 or 3). N is the number 98 | of videos that will be written. Video dimensions can differ between 99 | different video keys but not between different steps for the same 100 | video key. 101 | """ 102 | 103 | @abc.abstractmethod 104 | def write_audios( 105 | self, step: int, audios: Mapping[str, Array], *, sample_rate: int): 106 | """Write audios for the step. 107 | 108 | Consecutive calls to this method can provide different sets of audios. 109 | Repeated writes for the same audio key at the same step are not allowed. 
110 | 111 | Warning: Not all MetricWriter implementation support writing audios! 112 | 113 | Args: 114 | step: Step at which the audios occurred. 115 | audios: Mapping from audio key to audios. Audios should have the shape 116 | [N, T, C], where T is the time length and C the number of channels 117 | (1 - mono, 2 - stereo, >= 3 - surround; not all writers support any 118 | number of channels). N is the number of audios that will be written. 119 | Audio dimensions can differ between different audio keys but not between 120 | different steps for the same audio key. Values should be floating-point 121 | values in [-1, +1]. 122 | sample_rate: Sample rate for the audios. 123 | """ 124 | 125 | @abc.abstractmethod 126 | def write_texts(self, step: int, texts: Mapping[str, str]): 127 | """Writes text snippets for the step. 128 | 129 | Warning: Not all MetricWriter implementation support writing text! 130 | 131 | Args: 132 | step: Step at which the text snippets occurred. 133 | texts: Mapping from name to text snippet. 134 | """ 135 | 136 | @abc.abstractmethod 137 | def write_histograms(self, 138 | step: int, 139 | arrays: Mapping[str, Array], 140 | num_buckets: Optional[Mapping[str, int]] = None): 141 | """Writes histograms for the step. 142 | 143 | Consecutive calls to this method can provide different sets of scalars. 144 | Repeated writes for the same metric at the same step are not allowed. 145 | 146 | Warning: Not all MetricWriter implementation support writing histograms! 147 | 148 | Args: 149 | step: Step at which the arrays were generated. 150 | arrays: Mapping from name to arrays to summarize. 151 | num_buckets: Number of buckets used to create the histogram of the arrays. 152 | The default number of buckets depends on the particular implementation 153 | of the MetricWriter. 154 | """ 155 | 156 | def write_pointcloud( 157 | self, 158 | step: int, 159 | point_clouds: Mapping[str, Array], 160 | *, 161 | point_colors: Mapping[str, Array] | None = None, 162 | configs: Mapping[str, str | float | bool | None] | None = None, 163 | ): 164 | """Writes point cloud summaries. 165 | 166 | Args: 167 | step: Step at which the point cloud was generated. 168 | point_clouds: Mapping from point clouds key to point cloud of shape [N, 3] 169 | array of point coordinates. 170 | point_colors: Mapping from point colors key to [N, 3] array of point 171 | colors. 172 | configs: A dictionary of configuration options for the point cloud. 173 | """ 174 | raise NotImplementedError() 175 | 176 | @abc.abstractmethod 177 | def write_hparams(self, hparams: Mapping[str, Any]): 178 | """Write hyper parameters. 179 | 180 | Do not call twice. 181 | 182 | Args: 183 | hparams: Flat mapping from hyper parameter name to value. 184 | """ 185 | 186 | @abc.abstractmethod 187 | def flush(self): 188 | """Tells the MetricWriter to write out any cached values.""" 189 | 190 | @abc.abstractmethod 191 | def close(self): 192 | """Flushes and closes the MetricWriter. 193 | 194 | Calling any method on MetricWriter after MetricWriter.close() 195 | is undefined behavior. 196 | """ 197 | -------------------------------------------------------------------------------- /clu/metric_writers/tf/summary_writer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for SummaryWriter.""" 16 | 17 | import collections 18 | import os 19 | 20 | from clu.metric_writers.tf import summary_writer 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from tensorboard.plugins.hparams import plugin_data_pb2 25 | 26 | 27 | def _load_summaries_data(logdir): 28 | """Loads raw summaries data from events in a logdir.""" 29 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 30 | data = collections.defaultdict(dict) 31 | metadata = collections.defaultdict(dict) 32 | for path in paths: 33 | for event in tf.compat.v1.train.summary_iterator(path): 34 | for value in event.summary.value: 35 | data[event.step][value.tag] = tf.make_ndarray(value.tensor) 36 | if value.HasField("metadata"): 37 | metadata[event.step][value.tag] = value.metadata.SerializeToString() 38 | return data, metadata 39 | 40 | 41 | def _load_histograms_data(logdir): 42 | """Loads tensor summaries from events in a logdir.""" 43 | # Note: new versions of histograms don't use the HistogramProto type, but 44 | # they are written as tensors representing the bounds and counts of buckets, 45 | # with plugin_name = "histogram". 46 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 47 | data = {} 48 | for path in paths: 49 | for event in tf.compat.v1.train.summary_iterator(path): 50 | for value in event.summary.value: 51 | current_steps, current_tensors = data.get(value.tag, ([], [])) 52 | data[value.tag] = (current_steps + [event.step], 53 | current_tensors + [tf.make_ndarray(value.tensor)]) 54 | return { 55 | tag: (np.stack(steps), np.stack(tensors)) 56 | for tag, (steps, tensors) in data.items() 57 | } 58 | 59 | 60 | def _load_scalars_data(logdir: str): 61 | """Loads scalar summaries from events in a logdir.""" 62 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 63 | data = collections.defaultdict(dict) 64 | for path in paths: 65 | for event in tf.compat.v1.train.summary_iterator(path): 66 | for value in event.summary.value: 67 | data[event.step][value.tag] = tf.make_ndarray(value.tensor).flat[0] 68 | 69 | return data 70 | 71 | 72 | def _load_pointcloud_data(logdir: str): 73 | """Loads pointcloud summaries from events in a logdir.""" 74 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 75 | data = collections.defaultdict(dict) 76 | for path in paths: 77 | for event in tf.compat.v1.train.summary_iterator(path): 78 | for value in event.summary.value: 79 | if value.metadata.plugin_data.plugin_name == "mesh": 80 | if "config" not in value.tag: 81 | data[event.step][value.tag] = tf.make_ndarray(value.tensor) 82 | else: 83 | data[event.step][value.tag] = value.metadata.plugin_data.content 84 | return data 85 | 86 | 87 | def _load_hparams(logdir: str): 88 | """Loads hparams summaries from events in a logdir.""" 89 | paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*")) 90 | # data = collections.defaultdict(dict) 91 | hparams = [] 92 | for path in paths: 93 | for event in tf.compat.v1.train.summary_iterator(path): 94 | for value in 
event.summary.value: 95 | if value.metadata.plugin_data.plugin_name == "hparams": 96 | hparams.append(plugin_data_pb2.HParamsPluginData.FromString( 97 | value.metadata.plugin_data.content)) 98 | return hparams 99 | 100 | 101 | class SummaryWriterTest(tf.test.TestCase): 102 | 103 | def setUp(self): 104 | super().setUp() 105 | self.logdir = self.get_temp_dir() 106 | self.writer = summary_writer.SummaryWriter(self.logdir) 107 | 108 | def test_write_summaries(self): 109 | self.writer.write_summaries( 110 | 11, 111 | {"a": np.eye(3, dtype=np.uint8), 112 | "b": np.eye(2, dtype=np.float32)}, 113 | {"a": np.ones((2, 3)).tobytes()}) 114 | self.writer.flush() 115 | data, metadata = _load_summaries_data(self.logdir) 116 | self.assertAllClose( 117 | data[11], 118 | {"a": np.eye(3, dtype=np.uint8), "b": np.eye(2, dtype=np.float32)}) 119 | self.assertIn("a", metadata[11]) 120 | 121 | def test_write_scalar(self): 122 | self.writer.write_scalars(11, {"a": 0.6, "b": 15}) 123 | self.writer.write_scalars(20, {"a": 0.8, "b": 12}) 124 | self.writer.flush() 125 | data = _load_scalars_data(self.logdir) 126 | self.assertAllClose(data[11], {"a": 0.6, "b": 15}) 127 | self.assertAllClose(data[20], {"a": 0.8, "b": 12}) 128 | 129 | def test_write_histograms(self): 130 | self.writer.write_histograms( 131 | 0, { 132 | "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]), 133 | "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]), 134 | }, num_buckets={"a": 2, "b": 2}) 135 | self.writer.write_histograms( 136 | 2, { 137 | "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]), 138 | "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]), 139 | }, num_buckets={"a": 2, "b": 2}) 140 | self.writer.flush() 141 | data = _load_histograms_data(self.logdir) 142 | # In the histograms, each tuple represents 143 | # (bucket_min, bucket_max, bucket_count), where bucket_min is inclusive and 144 | # bucket_max is exclusive (except the last bucket_max which is inclusive). 145 | expected_histograms_a = [ 146 | # Step 0. 147 | [(0.1, 0.4, 3), (0.4, 0.7, 2)], 148 | # Step 1. 149 | [(-0.1, 0.2, 2), (0.2, 0.5, 3)], 150 | ] 151 | self.assertAllClose(data["a"], ([0, 2], expected_histograms_a)) 152 | expected_histograms_b = [ 153 | # Step 0. 154 | [(-0.1, 0.15, 1), (0.15, 0.4, 4)], 155 | # Step 1. 
156 | [(0.0, 0.35, 4), (0.35, 0.7, 1)], 157 | ] 158 | self.assertAllClose(data["b"], ([0, 2], expected_histograms_b)) 159 | 160 | def test_write_pointcloud(self): 161 | point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32) 162 | point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32) 163 | config = { 164 | "material": "PointCloudMaterial", 165 | "size": 0.09, 166 | } 167 | self.writer.write_pointcloud( 168 | step=0, 169 | point_clouds={"pcd": point_clouds}, 170 | point_colors={"pcd": point_colors}, 171 | configs={"config": config}, 172 | ) 173 | self.writer.flush() 174 | data = _load_pointcloud_data(self.logdir) 175 | self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds) 176 | self.assertAllClose(data[0]["pcd_COLOR"], point_colors) 177 | 178 | def test_hparams(self): 179 | self.writer.write_hparams(dict(batch_size=512, num_epochs=90)) 180 | hparams = _load_hparams(self.logdir) 181 | self.assertLen(hparams, 1) 182 | hparams_dict = hparams[0].session_start_info.hparams 183 | self.assertLen(hparams_dict, 2) 184 | self.assertEqual(512, hparams_dict["batch_size"].number_value) 185 | self.assertEqual(90, hparams_dict["num_epochs"].number_value) 186 | 187 | def test_hparams_nested(self): 188 | config = { 189 | "list": [1, 2], 190 | "tuple": (3, 4), 191 | "subconfig": { 192 | "value": "a", 193 | "list": [10, 20], 194 | }, 195 | } 196 | self.writer.write_hparams(config) 197 | hparams = _load_hparams(self.logdir) 198 | self.assertLen(hparams, 1) 199 | hparams_dict = hparams[0].session_start_info.hparams 200 | self.assertLen(hparams_dict, 7) 201 | self.assertEqual(1, hparams_dict["list.0"].number_value) 202 | self.assertEqual(2, hparams_dict["list.1"].number_value) 203 | self.assertEqual(3, hparams_dict["tuple.0"].number_value) 204 | self.assertEqual(4, hparams_dict["tuple.1"].number_value) 205 | self.assertEqual("a", hparams_dict["subconfig.value"].string_value) 206 | self.assertEqual(10, hparams_dict["subconfig.list.0"].number_value) 207 | self.assertEqual(20, hparams_dict["subconfig.list.1"].number_value) 208 | 209 | if __name__ == "__main__": 210 | tf.test.main() 211 | -------------------------------------------------------------------------------- /clu/data/dataset_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Interface for dataset iterators. 16 | 17 | This module provides the DatasetIterator interface. The intention is that 18 | several frameworks providing datasets can implement this interface without 19 | knowing anything about the framework used for the model and the training loop. 20 | Likewise, training loops can assume they get a DatasetIterator object and do not 21 | need to care about the specifics of the input pipelines. 22 | 23 | This module does not depend on TensorFlow. The interface is generic and users 24 | don't have to use `tf.data` to construct a DatasetIterator. 
However, if they 25 | use `tf.data` they can simply wrap their `tf.data.Dataset` object with 26 | `TfDatasetIterator` to satisfy the interface. 27 | """ 28 | from __future__ import annotations 29 | 30 | import abc 31 | import collections.abc 32 | import concurrent.futures 33 | import dataclasses 34 | import os 35 | import threading 36 | import typing 37 | from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union 38 | 39 | from absl import logging 40 | from clu import asynclib 41 | from etils import epath 42 | import jax.numpy as jnp  # Just for type checking. 43 | import numpy as np 44 | import numpy.typing as npt 45 | 46 | Array = Union[np.ndarray, jnp.ndarray] 47 | # Sizes of dimensions, None means the dimension size is unknown. 48 | Shape = Tuple[Optional[int], ...] 49 | 50 | 51 | @dataclasses.dataclass(frozen=True) 52 | class ArraySpec: 53 | """Describes an array via its dtype and shape.""" 54 | dtype: npt.DTypeLike 55 | shape: Shape 56 | 57 | def __repr__(self): 58 | return f"ArraySpec(dtype={np.dtype(self.dtype).name}, shape={self.shape})" 59 | 60 | def __str__(self): 61 | return f"{np.dtype(self.dtype).name}{list(self.shape)}" 62 | 63 | 64 | # Elements are PyTrees with NumPy/JAX arrays. 65 | 66 | # Anything can be a PyTree (it's either a container or leaf). We define 67 | # PyTree[T] as a PyTree where all leaves are of type T. 68 | # See https://jax.readthedocs.io/en/latest/pytrees.html. 69 | L = TypeVar("L")  # pylint: disable=invalid-name 70 | 71 | PyTree = Union[L, Sequence["PyTree[L]"], Mapping[str, "PyTree[L]"]] 72 | 73 | Element = PyTree[Array] 74 | ElementSpec = PyTree[ArraySpec] 75 | 76 | 77 | class DatasetIterator(collections.abc.Iterator):  # pytype: disable=ignored-abstractmethod 78 | """Generic interface for iterating over a dataset. 79 | 80 | This does not support __getitem__ since it cannot be implemented efficiently 81 | for many datasets. However, datasets should allow starting the iterator from 82 | an arbitrary position. 83 | 84 | The element_spec property helps consumers to validate the input without 85 | reading data. This is similar to `tf.data.Dataset.element_spec`. 86 | 87 | Subclasses may decide to not read/write checkpoints if their state is 88 | sufficiently tracked externally (e.g. input pipelines that can be correctly 89 | restarted from the step number). 90 | """ 91 | 92 | def get_next(self) -> Element: 93 | """Returns the next element.""" 94 | logging.error( 95 | "DatasetIterator.get_next() is deprecated. Please use next().") 96 | # Subclasses should implement __next__() and remove calls to get_next(). 97 | return next(self) 98 | 99 | def reset(self): 100 | """Resets the iterator back to the beginning.""" 101 | raise NotImplementedError 102 | 103 | @property 104 | @abc.abstractmethod 105 | def element_spec(self) -> ElementSpec: 106 | """Returns the element spec.""" 107 | raise NotImplementedError() 108 | 109 | def save(self, filename: epath.Path): 110 | """Saves the state of the iterator to a file. 111 | 112 | This should only handle this iterator - not iterators in other processes. 113 | 114 | Args: 115 | filename: Name of the checkpoint. 116 | """ 117 | raise NotImplementedError() 118 | 119 | def restore(self, filename: epath.Path): 120 | """Restores the iterator from a file (if available). 121 | 122 | This should only handle this iterator - not iterators in other processes. 123 | 124 | Args: 125 | filename: Name of the checkpoint. 
126 | """ 127 | raise NotImplementedError() 128 | 129 | def load(self, filename: epath.Path): 130 | logging.error("DatasetIterator.load() is deprecated. Please use restore().") 131 | return self.restore(filename) 132 | 133 | 134 | class TfDatasetIterator(DatasetIterator): 135 | """DatasetIterator for wrapping a `tf.data.Dataset`.""" 136 | 137 | def __init__(self, dataset, *, checkpoint: bool): 138 | """Wraps `tf.data.Dataset` object into the `DatasetIterator` interface. 139 | 140 | Warning: Do not wrap this interator to do asynchronous prefetching if you 141 | use `checkpoint=True` (default). tf.data iterators must be saved() 142 | synchronously. 143 | 144 | Args: 145 | dataset: The dataset to wrap. Elements are converted to NumPy arrays but 146 | no additional prefetching is done. tf.data should automatically prefetch 147 | elements (to CPU memory). 148 | checkpoint: Whether to checkpoint the dataset iterator object. 149 | Checkpointing dataset iterators is required for handling job 150 | pre-emptions but depending on your input pipeline can result in very 151 | large checkpoints. If set to False save() and load() are no-ops. 152 | """ 153 | try: 154 | # Since this is the only class in this module using TF we only import 155 | # tensorflow if needed. 156 | if typing.TYPE_CHECKING: 157 | tf = Any 158 | else: 159 | import tensorflow as tf # pylint: disable=g-import-not-at-top 160 | except ImportError as e: 161 | raise RuntimeError("When using TfDatasetIterator your binary must " 162 | "depend on //third_party/py/tensorflow.") from e 163 | self._tf = tf 164 | 165 | if not isinstance(dataset, tf.data.Dataset): 166 | raise ValueError("`dataset` must be an instance of `tf.data.Dataset` " 167 | f"but got {type(dataset)}.") 168 | self._dataset = dataset 169 | self._checkpoint = checkpoint 170 | assert self.element_spec # Verify element spec. 171 | self.iterator = iter(dataset) 172 | self._ckpt = tf.train.Checkpoint(ds=self.iterator) 173 | 174 | def get_next(self) -> Element: 175 | return next(self) 176 | 177 | def __next__(self) -> Element: 178 | return {k: np.asarray(v) for k, v in next(self.iterator).items()} 179 | 180 | def reset(self): 181 | self.iterator = iter(self._dataset) 182 | self._ckpt = self._tf.train.Checkpoint(ds=self.iterator) 183 | 184 | @property 185 | def element_spec(self) -> ElementSpec: 186 | element_spec = self._dataset.element_spec 187 | if not isinstance(element_spec, dict): 188 | raise ValueError("Dataset elements must be flat dictionaries but got " 189 | f"{element_spec}.") 190 | invalid_features = [ 191 | k for k, v in element_spec.items() 192 | if not isinstance(v, self._tf.TensorSpec) 193 | ] 194 | if invalid_features: 195 | raise ValueError(f"Features {invalid_features} are not tensors. Dataset " 196 | "elements must be flat dictionaries of tensors.") 197 | return { 198 | k: ArraySpec(dtype=v.dtype.as_numpy_dtype, shape=tuple(v.shape)) 199 | for k, v in element_spec.items() 200 | } 201 | 202 | def save(self, filename: epath.Path): 203 | if self._checkpoint: 204 | self._ckpt.write(os.fspath(filename)) 205 | 206 | def restore(self, filename: epath.Path): 207 | if self._checkpoint: 208 | self._ckpt.read(os.fspath(filename)).assert_consumed() 209 | 210 | 211 | class PeekableDatasetIterator(DatasetIterator): 212 | """Wraps a DatasetIterator to provide a peek() method. 213 | 214 | This allows to look at the next element which can be useful in 2 scenarios: 215 | a) Get the structure of elements if the element_spec property is not 216 | supported. 
217 | b) Request the next element without consuming it. This is especially handy to 218 | trigger reading of the first element while the model is being initialized. 219 | 220 | Example use case: 221 | >>> pool = clu.asynclib.Pool() 222 | >>> @pool 223 | >>> def warmup_input_pipeline(): 224 | >>> train_iter.peek() 225 | >>> first_batch_ready = warmup_input_pipeline() 226 | >>> # Do other stuff... 227 | >>> first_batch_ready.result() # wait for input pipeline to be ready. 228 | """ 229 | 230 | def __init__(self, it: DatasetIterator): 231 | self._it = it 232 | # Mutex for self._it. 233 | self._mutex = threading.Lock() 234 | self._peek: Optional[Element] = None 235 | self._pool = None 236 | self._peek_future = None 237 | 238 | def __next__(self) -> Element: 239 | with self._mutex: 240 | if self._peek is None: 241 | return next(self._it) 242 | peek = self._peek 243 | self._peek = None 244 | return peek 245 | 246 | def reset(self): 247 | with self._mutex: 248 | self._it.reset() 249 | self._peek = None 250 | self._pool = None 251 | self._peek_future = None 252 | 253 | @property 254 | def element_spec(self) -> ElementSpec: 255 | return self._it.element_spec 256 | 257 | def peek(self) -> Element: 258 | """Returns the next element without consuming it. 259 | 260 | This will get the next element from the underlying iterator. The element 261 | is stored and return on the next call of __next__(). 262 | 263 | Returns: 264 | The next element. 265 | """ 266 | if self._peek is None: 267 | self._peek = next(self) 268 | return self._peek 269 | 270 | def peek_async(self) -> concurrent.futures.Future[Element]: 271 | """Same as peek() but returns the Future of the element. 272 | 273 | Users can call this to warm up the iterator. 274 | 275 | Returns: 276 | Future with the next element. The element is also kept and returned on the 277 | next call of __next__(). 278 | """ 279 | with self._mutex: 280 | if self._peek_future is None: 281 | if self._pool is None: 282 | self._pool = asynclib.Pool(max_workers=1) 283 | self._peek_future = self._pool(self.peek)() 284 | return self._peek_future 285 | 286 | def save(self, filename: epath.Path): 287 | with self._mutex: 288 | self._it.save(filename) 289 | 290 | def restore(self, filename: epath.Path): 291 | with self._mutex: 292 | self._it.restore(filename) 293 | -------------------------------------------------------------------------------- /clu/periodic_actions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
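# A minimal sketch of how the hooks tested below are typically driven by a
# training loop (hypothetical; `train_step`, `num_train_steps`, `writer` and
# `workdir` are illustrative names, not part of this file):
#
#   report = periodic_actions.ReportProgress(
#       num_train_steps=num_train_steps, writer=writer)
#   profile = periodic_actions.Profile(logdir=workdir, every_steps=1000)
#   for step in range(1, num_train_steps + 1):
#     train_step()
#     report(step)   # Periodic actions must be called after every step.
#     profile(step)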
14 | 15 | """Tests for perodic actions.""" 16 | 17 | import tempfile 18 | import time 19 | from unittest import mock 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from clu import periodic_actions 24 | 25 | 26 | class ReportProgressTest(parameterized.TestCase): 27 | 28 | def test_every_steps(self): 29 | hook = periodic_actions.ReportProgress( 30 | every_steps=4, every_secs=None, num_train_steps=10 31 | ) 32 | t = time.monotonic() 33 | with self.assertLogs(level="INFO") as logs: 34 | self.assertFalse(hook(1, t)) 35 | t += 0.11 36 | self.assertFalse(hook(2, t)) 37 | t += 0.13 38 | self.assertFalse(hook(3, t)) 39 | t += 0.12 40 | self.assertTrue(hook(4, t)) 41 | # We did 1 step every 0.12s => 8.333 steps/s. 42 | self.assertEqual( 43 | logs.output, 44 | [ 45 | "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10)," 46 | " ETA: 0m" 47 | ], 48 | ) 49 | 50 | def test_every_secs(self): 51 | hook = periodic_actions.ReportProgress( 52 | every_steps=None, every_secs=0.3, num_train_steps=10 53 | ) 54 | t = time.monotonic() 55 | with self.assertLogs(level="INFO") as logs: 56 | self.assertFalse(hook(1, t)) 57 | t += 0.11 58 | self.assertFalse(hook(2, t)) 59 | t += 0.13 60 | self.assertFalse(hook(3, t)) 61 | t += 0.12 62 | self.assertTrue(hook(4, t)) 63 | # We did 1 step every 0.12s => 8.333 steps/s. 64 | self.assertEqual( 65 | logs.output, 66 | [ 67 | "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10)," 68 | " ETA: 0m" 69 | ], 70 | ) 71 | 72 | def test_without_num_train_steps(self): 73 | report = periodic_actions.ReportProgress(every_steps=2) 74 | t = time.monotonic() 75 | with self.assertLogs(level="INFO") as logs: 76 | self.assertFalse(report(1, t)) 77 | self.assertTrue(report(2, t + 0.12)) 78 | # We did 1 step in 0.12s => 8.333 steps/s. 79 | self.assertEqual( 80 | logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] 81 | ) 82 | 83 | def test_with_persistent_notes(self): 84 | report = periodic_actions.ReportProgress(every_steps=2) 85 | report.set_persistent_notes("Hello world") 86 | t = time.monotonic() 87 | with self.assertLogs(level="INFO") as logs: 88 | self.assertFalse(report(1, t)) 89 | self.assertTrue(report(2, t + 0.12)) 90 | # We did 1 step in 0.12s => 8.333 steps/s. 91 | self.assertEqual( 92 | logs.output, 93 | ["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"], 94 | ) 95 | 96 | def test_unknown_cardinality(self): 97 | report = periodic_actions.ReportProgress(every_steps=2) 98 | t = time.monotonic() 99 | with self.assertLogs(level="INFO") as logs: 100 | self.assertFalse(report(1, t)) 101 | self.assertTrue(report(2, t + 0.12)) 102 | # We did 1 step in 0.12s => 8.333 steps/s. 103 | self.assertEqual( 104 | logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] 105 | ) 106 | 107 | def test_called_every_step(self): 108 | hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10) 109 | t = time.monotonic() 110 | with self.assertRaisesRegex( 111 | ValueError, "PeriodicAction must be called after every step" 112 | ): 113 | hook(1, t) 114 | hook(11, t) # Raises exception. 
115 | 116 | @parameterized.named_parameters( 117 | ("_nowait", False), 118 | ("_wait", True), 119 | ) 120 | @mock.patch("time.monotonic") 121 | def test_named(self, wait_jax_async_dispatch, mock_time): 122 | mock_time.return_value = 0 123 | hook = periodic_actions.ReportProgress( 124 | every_steps=1, every_secs=None, num_train_steps=10 125 | ) 126 | 127 | def _wait(): 128 | # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1) 129 | hook._executor.submit(lambda: None).result() 130 | 131 | self.assertFalse(hook(1)) # Never triggers on first execution. 132 | with hook.timed("test1", wait_jax_async_dispatch): 133 | _wait() 134 | mock_time.return_value = 1 135 | _wait() 136 | with hook.timed("test2", wait_jax_async_dispatch): 137 | _wait() 138 | mock_time.return_value = 2 139 | _wait() 140 | with hook.timed("test1", wait_jax_async_dispatch): 141 | _wait() 142 | mock_time.return_value = 3 143 | _wait() 144 | mock_time.return_value = 4 145 | with self.assertLogs(level="INFO") as logs: 146 | self.assertTrue(hook(2)) 147 | self.assertEqual( 148 | logs.output, 149 | [ 150 | "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:" 151 | " 0m (0m : 50.0% test1, 25.0% test2)" 152 | ], 153 | ) 154 | 155 | @mock.patch("time.monotonic") 156 | def test_write_metrics(self, time_mock): 157 | time_mock.return_value = 0 158 | writer_mock = mock.Mock() 159 | hook = periodic_actions.ReportProgress( 160 | every_steps=2, every_secs=None, writer=writer_mock 161 | ) 162 | time_mock.return_value = 1 163 | hook(1) 164 | time_mock.return_value = 2 165 | hook(2) 166 | self.assertEqual( 167 | writer_mock.write_scalars.mock_calls, 168 | [ 169 | mock.call(2, {"steps_per_sec": 1}), 170 | mock.call(2, {"uptime": 2}), 171 | ], 172 | ) 173 | 174 | 175 | class DummyProfilerSession: 176 | """Dummy Profiler that records the steps at which sessions started/ended.""" 177 | 178 | def __init__(self): 179 | self.step = None 180 | self.start_session_call_steps = [] 181 | self.end_session_call_steps = [] 182 | 183 | def start_session(self): 184 | self.start_session_call_steps.append(self.step) 185 | 186 | def end_session_and_get_url(self, tag): 187 | del tag 188 | self.end_session_call_steps.append(self.step) 189 | 190 | 191 | class ProfileTest(absltest.TestCase): 192 | 193 | @mock.patch.object(periodic_actions, "profiler", autospec=True) 194 | @mock.patch("time.monotonic") 195 | def test_every_steps(self, mock_time, mock_profiler): 196 | start_steps = [] 197 | stop_steps = [] 198 | step = 0 199 | 200 | def add_start_step(logdir): 201 | del logdir # unused 202 | start_steps.append(step) 203 | 204 | def add_stop_step(): 205 | stop_steps.append(step) 206 | 207 | mock_profiler.start.side_effect = add_start_step 208 | mock_profiler.stop.side_effect = add_stop_step 209 | hook = periodic_actions.Profile( 210 | logdir=tempfile.mkdtemp(), 211 | num_profile_steps=2, 212 | profile_duration_ms=2_000, 213 | first_profile=3, 214 | every_steps=7, 215 | ) 216 | for step in range(1, 18): 217 | mock_time.return_value = step - 0.5 if step == 9 else step 218 | hook(step) 219 | self.assertEqual([3, 7, 14], start_steps) 220 | # Note: profiling 7..10 instead of 7..9 because 7..9 took only 1.5 seconds. 
221 | self.assertEqual([5, 10, 16], stop_steps) 222 | 223 | 224 | class ProfileAllHostsTest(absltest.TestCase): 225 | 226 | @mock.patch.object(periodic_actions, "profiler", autospec=True) 227 | def test_every_steps(self, mock_profiler): 228 | start_steps = [] 229 | step = 0 230 | 231 | def profile_collect(logdir, callback, hosts, duration_ms): 232 | del logdir, callback, hosts, duration_ms # unused 233 | start_steps.append(step) 234 | 235 | mock_profiler.collect.side_effect = profile_collect 236 | hook = periodic_actions.ProfileAllHosts( 237 | logdir=tempfile.mkdtemp(), 238 | profile_duration_ms=2_000, 239 | first_profile=3, 240 | every_steps=7, 241 | ) 242 | for step in range(1, 18): 243 | hook(step) 244 | self.assertEqual([3, 7, 14], start_steps) 245 | 246 | 247 | class PeriodicCallbackTest(absltest.TestCase): 248 | 249 | def test_every_steps(self): 250 | callback = mock.Mock() 251 | hook = periodic_actions.PeriodicCallback( 252 | every_steps=2, callback_fn=callback 253 | ) 254 | 255 | for step in range(1, 10): 256 | hook(step, 3, remainder=step % 3) 257 | 258 | expected_calls = [ 259 | mock.call(remainder=2, step=2, t=3), 260 | mock.call(remainder=1, step=4, t=3), 261 | mock.call(remainder=0, step=6, t=3), 262 | mock.call(remainder=2, step=8, t=3), 263 | ] 264 | self.assertListEqual(expected_calls, callback.call_args_list) 265 | 266 | @mock.patch("time.monotonic") 267 | def test_every_secs(self, mock_time): 268 | callback = mock.Mock() 269 | hook = periodic_actions.PeriodicCallback(every_secs=2, callback_fn=callback) 270 | 271 | for step in range(1, 10): 272 | mock_time.return_value = float(step) 273 | hook(step, remainder=step % 5) 274 | # Note: time will be initialized at 1 so hook runs at steps 4 & 7. 275 | expected_calls = [ 276 | mock.call(remainder=4, step=4, t=4.0), 277 | mock.call(remainder=2, step=7, t=7.0), 278 | ] 279 | self.assertListEqual(expected_calls, callback.call_args_list) 280 | 281 | def test_on_steps(self): 282 | callback = mock.Mock() 283 | hook = periodic_actions.PeriodicCallback(on_steps=[8], callback_fn=callback) 284 | 285 | for step in range(1, 10): 286 | hook(step, remainder=step % 3) 287 | 288 | callback.assert_called_once_with(remainder=2, step=8, t=mock.ANY) 289 | 290 | def test_async_execution(self): 291 | out = [] 292 | 293 | def cb(step, t): 294 | del t 295 | out.append(step) 296 | 297 | hook = periodic_actions.PeriodicCallback( 298 | every_steps=1, callback_fn=cb, execute_async=True 299 | ) 300 | hook(0) 301 | hook(1) 302 | hook(2) 303 | hook(3) 304 | # Block till all the hooks have finished. 305 | hook.get_last_callback_result().result() 306 | # Check order of execution is preserved. 307 | self.assertListEqual(out, [0, 1, 2, 3]) 308 | 309 | def test_error_async_is_forwarded(self): 310 | 311 | def cb(step, t): 312 | del step 313 | del t 314 | raise Exception 315 | 316 | hook = periodic_actions.PeriodicCallback( 317 | every_steps=1, callback_fn=cb, execute_async=True 318 | ) 319 | 320 | hook(0) 321 | 322 | with self.assertRaises(Exception): 323 | hook(1) 324 | 325 | def test_function_without_step_and_time(self): 326 | 327 | # This must be used with pass_step_and_time=False. 
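    # (With pass_step_and_time=False the hook does not forward `step` and `t`,
    # so a zero-argument callback like the one below works.)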
328 | def cb(): 329 | return 5 330 | 331 | hook = periodic_actions.PeriodicCallback( 332 | every_steps=1, callback_fn=cb, pass_step_and_time=False 333 | ) 334 | hook(0) 335 | hook(1) 336 | self.assertEqual(hook.get_last_callback_result(), 5) 337 | 338 | 339 | if __name__ == "__main__": 340 | absltest.main() 341 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /clu/parameter_overview.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper function for creating and logging JAX variable overviews.""" 16 | 17 | from collections.abc import Callable, Mapping, Sequence 18 | import dataclasses 19 | from typing import Any 20 | 21 | from absl import logging 22 | 23 | import flax 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | _ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]] 29 | 30 | 31 | @dataclasses.dataclass 32 | class _ParamRow: 33 | name: str 34 | shape: tuple[int, ...] 35 | dtype: str 36 | size: int 37 | 38 | 39 | @dataclasses.dataclass 40 | class _ParamRowWithSharding(_ParamRow): 41 | sharding: tuple[int | None, ...] | str 42 | 43 | 44 | @dataclasses.dataclass 45 | class _ParamRowWithStats(_ParamRow): 46 | mean: float 47 | std: float 48 | 49 | 50 | @dataclasses.dataclass 51 | class _ParamRowWithStatsAndSharding(_ParamRowWithStats): 52 | sharding: tuple[int | None, ...] | str 53 | 54 | 55 | @jax.jit 56 | def _mean_std_jit(x): 57 | return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x) 58 | 59 | 60 | def _mean_std(x): 61 | mean = jax.tree_util.tree_map(lambda x: x.mean(), x) 62 | std = jax.tree_util.tree_map(lambda x: x.std(), x) 63 | return mean, std 64 | 65 | 66 | def flatten_dict( 67 | input_dict: dict[str, Any], *, prefix: str = "", delimiter: str = "/" 68 | ) -> dict[str, Any]: 69 | """Flattens the keys of a nested dictionary.""" 70 | output_dict = {} 71 | for key, value in input_dict.items(): 72 | nested_key = f"{prefix}{delimiter}{key}" if prefix else key 73 | if isinstance(value, (dict, flax.core.FrozenDict)): 74 | output_dict.update( 75 | flatten_dict(value, prefix=nested_key, delimiter=delimiter) 76 | ) 77 | else: 78 | output_dict[nested_key] = value 79 | return output_dict 80 | 81 | 82 | def _count_parameters(params: _ParamsContainer) -> int: 83 | """Returns the count of variables for the module or parameter dictionary.""" 84 | params = flatten_dict(params) 85 | return sum(np.prod(v.shape) for v in params.values() if v is not None) 86 | 87 | 88 | def _parameters_size(params: _ParamsContainer) -> int: 89 | """Returns total size (bytes) for the module or parameter dictionary.""" 90 | params = flatten_dict(params) 91 | return sum( 92 | np.prod(v.shape) * v.dtype.itemsize 93 | for v in params.values() 94 | if v is not None 95 | ) 96 | 97 | 98 | def count_parameters(params: _ParamsContainer) -> int: 99 | """Returns the count of variables for the module or parameter dictionary.""" 100 | 101 | return _count_parameters(params) 102 | 103 | 104 | def _make_row(name, value) -> _ParamRow: 105 | if value is None: 106 | return _ParamRow( 107 | name=name, 108 | shape=(), 109 | dtype="", 110 | size=0, 111 | ) 112 | return _ParamRow( 113 | name=name, 114 | shape=value.shape, 115 | dtype=str(value.dtype), 116 | size=int(np.prod(value.shape)), 117 | ) 118 | 119 | 120 | def _make_row_with_sharding(name, value) -> _ParamRowWithSharding: 121 | row = _make_row(name, value) 122 | if hasattr(value, "sharding"): 123 | if hasattr(value.sharding, "spec"): 124 | sharding = tuple(value.sharding.spec) 125 | 
else: 126 | sharding = str(value.sharding) 127 | else: 128 | sharding = () 129 | return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding) 130 | 131 | 132 | def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats: 133 | row = _make_row(name, value) 134 | mean = mean or 0.0 135 | std = std or 0.0 136 | return _ParamRowWithStats( 137 | **dataclasses.asdict(row), 138 | mean=float(jax.device_get(mean)), 139 | std=float(jax.device_get(std)), 140 | ) 141 | 142 | 143 | def _make_row_with_stats_and_sharding( 144 | name, value, mean, std 145 | ) -> _ParamRowWithStatsAndSharding: 146 | row = _make_row_with_sharding(name, value) 147 | return _ParamRowWithStatsAndSharding( 148 | **dataclasses.asdict(row), 149 | mean=float(jax.device_get(mean)), 150 | std=float(jax.device_get(std)), 151 | ) 152 | 153 | 154 | def _get_parameter_rows( 155 | params: _ParamsContainer, 156 | *, 157 | include_stats: bool | str = False, 158 | ) -> list[_ParamRow]: 159 | """Returns information about parameters as a list of dictionaries. 160 | 161 | Args: 162 | params: Dictionary with parameters as NumPy arrays. The dictionary can be 163 | nested. Alternatively a `tf.Module` can be provided, in which case the 164 | `trainable_variables` of the module will be used. 165 | include_stats: If True, add columns with mean and std for each variable. If 166 | the string "sharding", add column a column with the sharding of the 167 | variable. If the string "global", params are sharded global arrays and 168 | this function assumes it is called on every host, i.e. can use 169 | collectives. The sharding of the variables is also added as a column. 170 | 171 | Returns: 172 | A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value 173 | of `include_stats`. 174 | """ 175 | if not isinstance(params, (dict, flax.core.FrozenDict)): 176 | raise ValueError( 177 | f"Expected `params` to be a dictionary but got {type(params)}" 178 | ) 179 | 180 | params = flatten_dict(params) 181 | if params: 182 | names, values = map(list, tuple(zip(*sorted(params.items())))) 183 | else: 184 | names, values = [], [] 185 | 186 | match include_stats: 187 | case False: 188 | return jax.tree_util.tree_map(_make_row, names, values) 189 | 190 | case True: 191 | mean_and_std = _mean_std(values) 192 | return jax.tree_util.tree_map( 193 | _make_row_with_stats, names, values, *mean_and_std 194 | ) 195 | 196 | case "global": 197 | mean_and_std = _mean_std_jit(values) 198 | return jax.tree_util.tree_map( 199 | _make_row_with_stats_and_sharding, names, values, *mean_and_std 200 | ) 201 | 202 | case "sharding": 203 | return jax.tree_util.tree_map(_make_row_with_sharding, names, values) 204 | 205 | case _: 206 | raise ValueError(f"Unknown `include_stats`: {include_stats}") 207 | 208 | 209 | def _default_table_value_formatter(value): 210 | """Formats ints with "," between thousands and floats to 3 digits.""" 211 | if isinstance(value, bool): 212 | return str(value) 213 | elif isinstance(value, int): 214 | return "{:,}".format(value) 215 | elif isinstance(value, float): 216 | return "{:.3}".format(value) 217 | else: 218 | return str(value) 219 | 220 | 221 | def make_table( 222 | rows: list[Any], 223 | *, 224 | column_names: Sequence[str] | None = None, 225 | value_formatter: Callable[[Any], str] = _default_table_value_formatter, 226 | max_lines: int | None = None, 227 | ) -> str: 228 | """Renders a list of rows to a table. 229 | 230 | Args: 231 | rows: List of dataclass instances of a single type (e.g. `ParamRow`). 
232 | column_names: List of columns that that should be included in the output. If 233 | not provided, then the columns are taken from keys of the first row. 234 | value_formatter: Callable used to format cell values. 235 | max_lines: Don't render a table longer than this. 236 | 237 | Returns: 238 | A string representation of the table in the form: 239 | 240 | +---------+---------+ 241 | | Col1 | Col2 | 242 | +---------+---------+ 243 | | value11 | value12 | 244 | | value21 | value22 | 245 | +---------+---------+ 246 | """ 247 | 248 | if any(not dataclasses.is_dataclass(row) for row in rows): 249 | raise ValueError("Expected `rows` to be list of dataclasses") 250 | if len(set(map(type, rows))) > 1: 251 | raise ValueError("Expected elements of `rows` be of same type.") 252 | 253 | class Column: 254 | 255 | def __init__(self, name, values): 256 | self.name = name.capitalize() 257 | self.values = values 258 | self.width = max(len(v) for v in values + [name]) 259 | 260 | if column_names is None: 261 | if not rows: 262 | return "(empty table)" 263 | column_names = [field.name for field in dataclasses.fields(rows[0])] 264 | 265 | columns = [ 266 | Column(name, [value_formatter(getattr(row, name)) for row in rows]) 267 | for name in column_names 268 | ] 269 | 270 | var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns) 271 | sep_line_format = var_line_format.replace(" ", "-").replace("|", "+") 272 | header = var_line_format.replace(">", "<").format(*[c.name for c in columns]) 273 | separator = sep_line_format.format(*["" for c in columns]) 274 | 275 | lines = [separator, header, separator] 276 | for i in range(len(rows)): 277 | if max_lines and len(lines) >= max_lines - 3: 278 | lines.append("[...]") 279 | break 280 | lines.append(var_line_format.format(*[c.values[i] for c in columns])) 281 | lines.append(separator) 282 | 283 | return "\n".join(lines) 284 | 285 | 286 | def _get_parameter_overview( 287 | params: _ParamsContainer, 288 | *, 289 | include_stats: bool | str = True, 290 | max_lines: int | None = None, 291 | ) -> str: 292 | """See get_parameter_overview().""" 293 | if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison 294 | params = jax.device_get(params) # A no-op if already numpy array. 295 | rows = _get_parameter_rows(params, include_stats=include_stats) 296 | RowType = { # pylint: disable=invalid-name 297 | False: _ParamRow, 298 | True: _ParamRowWithStats, 299 | "global": _ParamRowWithStatsAndSharding, 300 | "sharding": _ParamRowWithSharding, 301 | }[include_stats] 302 | # Pass in `column_names` to enable rendering empty tables. 303 | column_names = [field.name for field in dataclasses.fields(RowType)] 304 | table = make_table(rows, max_lines=max_lines, column_names=column_names) 305 | total_weights = _count_parameters(params) 306 | total_size = _parameters_size(params) 307 | return table + f"\nTotal: {total_weights:,} -- {total_size:,} bytes" 308 | 309 | 310 | def get_parameter_overview( 311 | params: _ParamsContainer, 312 | *, 313 | include_stats: bool | str = True, 314 | max_lines: int | None = None, 315 | ) -> str: 316 | """Returns a string with variables names, their shapes, count. 317 | 318 | Args: 319 | params: Dictionary with parameters as NumPy arrays. The dictionary can be 320 | nested. 321 | include_stats: If True, add columns with mean and std for each variable. If 322 | the string "sharding", add column a column with the sharding of the 323 | variable. 
If the string "global", params are sharded global arrays and 324 | this function assumes it is called on every host, i.e. can use 325 | collectives. The sharding of the variables is also added as a column. 326 | max_lines: If not `None`, the maximum number of variables to include. 327 | 328 | Returns: 329 | A string with a table like in the example. 330 | 331 | +----------------+---------------+------------+ 332 | | Name | Shape | Size | 333 | +----------------+---------------+------------+ 334 | | FC_1/weights:0 | (63612, 1024) | 65,138,688 | 335 | | FC_1/biases:0 | (1024,) | 1,024 | 336 | | FC_2/weights:0 | (1024, 32) | 32,768 | 337 | | FC_2/biases:0 | (32,) | 32 | 338 | +----------------+---------------+------------+ 339 | Total: 65,172,512 340 | """ 341 | 342 | return _get_parameter_overview( 343 | params, include_stats=include_stats, max_lines=max_lines 344 | ) 345 | 346 | 347 | def _log_parameter_overview( 348 | params: _ParamsContainer, 349 | *, 350 | include_stats: bool | str = True, 351 | max_lines: int | None = None, 352 | msg: str | None = None, 353 | jax_logging_process: int | None = None, 354 | ): 355 | """See log_parameter_overview().""" 356 | 357 | table = _get_parameter_overview( 358 | params, include_stats=include_stats, max_lines=max_lines 359 | ) 360 | if jax_logging_process is None or jax_logging_process == jax.process_index(): 361 | lines = [msg] if msg else [] 362 | lines += table.split("\n") 363 | # The table can be too large to fit into one log entry. 364 | for i in range(0, len(lines), 80): 365 | logging.info("\n%s", "\n".join(lines[i : i + 80])) 366 | 367 | 368 | def log_parameter_overview( 369 | params: _ParamsContainer, 370 | *, 371 | include_stats: bool | str = True, 372 | max_lines: int | None = None, 373 | msg: str | None = None, 374 | jax_logging_process: int | None = None, 375 | ): 376 | """Writes a table with variables name and shapes to INFO log. 377 | 378 | See get_parameter_overview for details. 379 | 380 | Args: 381 | params: Dictionary with parameters as NumPy arrays. The dictionary can be 382 | nested. 383 | include_stats: If True, add columns with mean and std for each variable. If 384 | the string "global", params are sharded global arrays and this function 385 | assumes it is called on every host, i.e. can use collectives. 386 | max_lines: If not `None`, the maximum number of variables to include. 387 | msg: Message to be logged before the overview. 388 | jax_logging_process: Which JAX process ID should do the logging. None = all. 389 | Use this to avoid logspam when include_stats="global". 390 | """ 391 | 392 | _log_parameter_overview( 393 | params, 394 | include_stats=include_stats, 395 | max_lines=max_lines, 396 | msg=msg, 397 | jax_logging_process=jax_logging_process, 398 | ) 399 | -------------------------------------------------------------------------------- /clu/deterministic_data_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for the deterministic_data module.""" 16 | import dataclasses 17 | import itertools 18 | import math 19 | 20 | from typing import Dict 21 | from unittest import mock 22 | 23 | from absl.testing import parameterized 24 | from clu import deterministic_data 25 | import jax 26 | from packaging import version 27 | import tensorflow as tf 28 | import tensorflow_datasets as tfds 29 | 30 | _use_split_info = version.parse("4.4.0") < version.parse( 31 | tfds.version.__version__) 32 | 33 | 34 | @dataclasses.dataclass 35 | class MyDatasetBuilder: 36 | 37 | name2len: Dict[str, int] # Number of examples per split. 38 | 39 | def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool, 40 | read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset: 41 | del shuffle_files, read_config, decoders 42 | if _use_split_info: 43 | split_infos = { 44 | k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0) 45 | for k, v in self.name2len.items() 46 | } 47 | instructions = split.to_absolute(split_infos) 48 | else: 49 | instructions = split.to_absolute(self.name2len) 50 | assert len(instructions) == 1 51 | from_ = instructions[0].from_ or 0 52 | to = instructions[0].to or self.name2len[instructions[0].splitname] 53 | return tf.data.Dataset.range(from_, to).map(lambda i: {"index": i}) 54 | 55 | 56 | @dataclasses.dataclass 57 | class FakeDatasetInfo: 58 | train_size: int = 9 59 | test_size: int = 8 60 | 61 | @property 62 | def splits(self): 63 | return { 64 | "train": tfds.core.SplitInfo("train", [self.train_size], 0), 65 | "test": tfds.core.SplitInfo("test", [self.test_size], 0) 66 | } 67 | 68 | 69 | class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase): 70 | """Tests for deterministic_data module.""" 71 | 72 | @parameterized.parameters( 73 | (9, 0, 1, True, "test[0:9]"), 74 | (9, 0, 2, True, "test[0:4]"), 75 | (9, 1, 2, True, "test[4:8]"), # Last example gets dropped. 76 | (9, 0, 3, True, "test[0:3]"), 77 | (9, 1, 3, True, "test[3:6]"), 78 | (9, 2, 3, True, "test[6:9]"), 79 | (9, 0, 1, False, "test[0:9]"), 80 | (9, 0, 2, False, "test[0:5]"), # First host gets an extra example. 81 | (9, 1, 2, False, "test[5:9]"), 82 | (8, 0, 3, False, "test[0:3]"), # First 2 hosts get 1 example each. 83 | (8, 1, 3, False, "test[3:6]"), 84 | (8, 2, 3, False, "test[6:8]"), 85 | ) 86 | def test_get_read_instruction_for_host_deprecated(self, num_examples: int, 87 | host_id: int, 88 | host_count: int, 89 | drop_remainder: bool, 90 | expected_spec: str): 91 | expected = tfds.core.ReadInstruction.from_spec(expected_spec) 92 | actual = deterministic_data.get_read_instruction_for_host( 93 | "test", 94 | num_examples, 95 | host_id=host_id, 96 | host_count=host_count, 97 | drop_remainder=drop_remainder) 98 | if _use_split_info: 99 | split_infos = { 100 | "test": tfds.core.SplitInfo( 101 | name="test", 102 | shard_lengths=[9], 103 | num_bytes=0, 104 | )} 105 | else: 106 | split_infos = {"test": 9} 107 | self.assertEqual( 108 | expected.to_absolute(split_infos), actual.to_absolute(split_infos)) 109 | 110 | @parameterized.parameters( 111 | # host_id, host_count, drop_remainder, spec, exected_spec_for_host 112 | # train split has 9 examples. 113 | (0, 1, True, "train", "train[0:9]"), 114 | (0, 2, True, "train", "train[0:4]"), 115 | (1, 2, True, "train", "train[4:8]"), # Last example gets dropped. 
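      # With drop_remainder=True each host reads exactly
      # num_examples // host_count examples, so trailing examples can be dropped.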
116 | (0, 3, True, "train", "train[0:3]"), 117 | (1, 3, True, "train", "train[3:6]"), 118 | (2, 3, True, "train", "train[6:9]"), 119 | (0, 1, False, "train", "train[0:9]"), 120 | (0, 2, False, "train", "train[0:5]"), # First host gets an extra example. 121 | (1, 2, False, "train", "train[5:9]"), 122 | # test split has 8 examples. 123 | (0, 3, False, "test", "test[0:3]"), # First 2 hosts get 1 example each. 124 | (1, 3, False, "test", "test[3:6]"), 125 | (2, 3, False, "test", "test[6:8]"), 126 | # Subsplits. 127 | (0, 2, True, "train[:50%]", "train[0:2]"), 128 | (1, 2, True, "train[:50%]", "train[2:4]"), 129 | (0, 2, True, "train[3:7]", "train[3:5]"), 130 | (1, 2, True, "train[3:7]", "train[5:7]"), 131 | (0, 2, True, "train[3:8]", "train[3:5]"), # Last example gets dropped. 132 | (1, 2, True, "train[3:8]", "train[5:7]"), 133 | # 2 splits. 134 | (0, 2, True, "train[3:7]+test", "train[3:5]+test[0:4]"), 135 | (1, 2, True, "train[3:7]+test", "train[5:7]+test[4:8]"), 136 | # First host gets an extra example. 137 | (0, 2, False, "train[3:8]+test[:5]", "train[3:6]+test[0:3]"), 138 | (1, 2, False, "train[3:8]+test[:5]", "train[6:8]+test[3:5]"), 139 | ) 140 | def test_get_read_instruction_for_host(self, host_id: int, host_count: int, 141 | drop_remainder: bool, spec: str, 142 | expected_spec_for_host: str): 143 | 144 | actual_spec_for_host = deterministic_data.get_read_instruction_for_host( 145 | spec, 146 | dataset_info=FakeDatasetInfo(), 147 | host_id=host_id, 148 | host_count=host_count, 149 | drop_remainder=drop_remainder) 150 | expected_spec_for_host = tfds.core.ReadInstruction.from_spec( 151 | expected_spec_for_host) 152 | self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host)) 153 | 154 | @parameterized.parameters( 155 | # host_id, host_count, balance_remainder, spec, exected_spec_for_host 156 | # test split has 10 examples. 157 | (0, 1, True, "test", "test[0:10]"), 158 | (0, 1, False, "test", "test[0:10]"), 159 | (0, 4, True, "test", "test[0:3]"), 160 | (1, 4, True, "test", "test[3:6]"), 161 | (2, 4, True, "test", "test[6:8]"), 162 | (3, 4, True, "test", "test[8:10]"), 163 | (0, 4, False, "test", "test[0:4]"), 164 | (1, 4, False, "test", "test[4:6]"), 165 | (2, 4, False, "test", "test[6:8]"), 166 | (3, 4, False, "test", "test[8:10]"), 167 | ) 168 | def test_get_read_instruction_balance_remainder(self, host_id: int, 169 | host_count: int, 170 | balance_remainder: bool, 171 | spec: str, 172 | expected_spec_for_host: str): 173 | actual_spec_for_host = deterministic_data.get_read_instruction_for_host( 174 | spec, 175 | dataset_info=FakeDatasetInfo(test_size=10), 176 | host_id=host_id, 177 | host_count=host_count, 178 | remainder_options=deterministic_data.RemainderOptions 179 | .BALANCE_ON_PROCESSES if balance_remainder else 180 | deterministic_data.RemainderOptions.ON_FIRST_PROCESS) 181 | expected_spec_for_host = tfds.core.ReadInstruction.from_spec( 182 | expected_spec_for_host) 183 | self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host)) 184 | 185 | @parameterized.parameters( 186 | (0, 0), # No hosts. 187 | (1, 1), # Only one host (host_id is zero-based. 188 | (-1, 1), # Negative host_id. 189 | (5, 2), # host_id bigger than number of hosts. 
190 | ) 191 | def test_get_read_instruction_for_host_fails(self, host_id: int, 192 | host_count: int): 193 | with self.assertRaises(ValueError): 194 | deterministic_data.get_read_instruction_for_host( 195 | "test", 11, host_id=host_id, host_count=host_count) 196 | 197 | def test_preprocess_with_per_example_rng(self): 198 | 199 | def preprocess_fn(features): 200 | features["b"] = tf.random.stateless_uniform([], features["rng"]) 201 | return features 202 | 203 | rng = jax.random.PRNGKey(42) 204 | ds_in = tf.data.Dataset.from_tensor_slices({"a": [37.2, 31.2, 39.0]}) 205 | ds_out = deterministic_data._preprocess_with_per_example_rng( 206 | ds_in, preprocess_fn, rng=rng) 207 | self.assertAllClose([ 208 | { 209 | "a": 37.2, 210 | "b": 0.79542184 211 | }, 212 | { 213 | "a": 31.2, 214 | "b": 0.45482683 215 | }, 216 | { 217 | "a": 39.0, 218 | "b": 0.85335636 219 | }, 220 | ], list(ds_out)) 221 | 222 | @parameterized.parameters(*itertools.product([2, "auto"], [True, False])) 223 | def test_create_dataset_padding(self, pad_up_to_batches, cardinality): 224 | dataset_builder = mock.Mock() 225 | dataset = tf.data.Dataset.from_tensor_slices( 226 | dict(x=tf.ones((12, 10)), y=tf.ones(12))) 227 | dataset_builder.as_dataset.return_value = dataset 228 | batch_dims = (2, 5) 229 | ds = deterministic_data.create_dataset( 230 | dataset_builder, 231 | split="(ignored)", 232 | batch_dims=batch_dims, 233 | num_epochs=1, 234 | shuffle=False, 235 | pad_up_to_batches=pad_up_to_batches, 236 | cardinality=12 if cardinality else None, 237 | ) 238 | ds_iter = iter(ds) 239 | self.assertAllClose( 240 | dict( 241 | x=tf.ones((2, 5, 10)), 242 | y=tf.ones((2, 5)), 243 | mask=tf.ones((2, 5), bool), 244 | ), next(ds_iter)) 245 | self.assertAllClose( 246 | dict( 247 | x=tf.reshape( 248 | tf.concat([tf.ones( 249 | (2, 10)), tf.zeros((8, 10))], axis=0), (2, 5, 10)), 250 | y=tf.reshape(tf.concat([tf.ones(2), tf.zeros(8)], axis=0), (2, 5)), 251 | mask=tf.reshape( 252 | tf.concat( 253 | [tf.ones(2, bool), tf.zeros(8, bool)], axis=0), (2, 5)), 254 | ), next(ds_iter)) 255 | with self.assertRaises(StopIteration): 256 | next(ds_iter) 257 | 258 | def test_create_dataset_padding_raises_error_cardinality(self): 259 | dataset_builder = mock.Mock() 260 | dataset = tf.data.Dataset.from_tensor_slices( 261 | dict(x=tf.ones((12, 10)), y=tf.ones(12))) 262 | dataset = dataset.filter(lambda x: True) 263 | dataset_builder.as_dataset.return_value = dataset 264 | batch_dims = (2, 5) 265 | with self.assertRaisesRegex( 266 | ValueError, 267 | r"^Cannot determine dataset cardinality."): 268 | deterministic_data.create_dataset( 269 | dataset_builder, 270 | split="(ignored)", 271 | batch_dims=batch_dims, 272 | num_epochs=1, 273 | shuffle=False, 274 | pad_up_to_batches=2, 275 | cardinality=None, 276 | ) 277 | 278 | def test_pad_dataset(self): 279 | dataset = tf.data.Dataset.from_tensor_slices( 280 | dict(x=tf.ones((12, 10)), y=tf.ones(12))) 281 | padded_dataset = deterministic_data.pad_dataset( 282 | dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12) 283 | self.assertAllClose( 284 | dict( 285 | x=tf.concat([tf.ones( 286 | (12, 10)), tf.zeros((8, 10))], axis=0), 287 | y=tf.concat([tf.ones(12), tf.zeros(8)], axis=0), 288 | mask=tf.concat( 289 | [tf.ones(12, bool), tf.zeros(8, bool)], axis=0)), 290 | next(iter(padded_dataset.batch(20)))) 291 | 292 | def test_pad_nested_dataset(self): 293 | dataset = tf.data.Dataset.from_tensor_slices( 294 | {"x": {"z": (tf.ones((12, 10)), tf.ones(12))}, 295 | "y": tf.ones((12, 4))}) 296 | 297 | def expected(*dims): 
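      # Ones for the 12 real examples followed by zeros for the 8 padded ones,
      # with per-example shape `dims`.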
298 | return tf.concat([tf.ones((12,) + dims), tf.zeros((8,) + dims)], axis=0) 299 | 300 | padded_dataset = deterministic_data.pad_dataset( 301 | dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12) 302 | self.assertAllClose( 303 | {"x": {"z": (expected(10), expected())}, 304 | "y": expected(4), 305 | "mask": tf.concat([tf.ones(12, bool), tf.zeros(8, bool)], axis=0)}, 306 | next(iter(padded_dataset.batch(20)))) 307 | 308 | @parameterized.parameters(*itertools.product(range(20), range(1, 4))) 309 | def test_same_cardinality_on_all_hosts(self, num_examples: int, 310 | host_count: int): 311 | builder = MyDatasetBuilder({"train": num_examples}) 312 | cardinalities = [] 313 | for host_id in range(host_count): 314 | split = deterministic_data.get_read_instruction_for_host( 315 | split="train", 316 | num_examples=num_examples, 317 | host_id=host_id, 318 | host_count=host_count, 319 | drop_remainder=True) 320 | ds = deterministic_data.create_dataset( 321 | builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1) 322 | cardinalities.append(ds.cardinality().numpy().item()) 323 | self.assertLen(set(cardinalities), 1) 324 | 325 | @parameterized.parameters(*itertools.product(range(20), range(1, 4))) 326 | def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int, 327 | host_count: int): 328 | builder = MyDatasetBuilder({"train": num_examples}) 329 | # All hosts should have the same number of batches. 330 | batch_size = 2 331 | pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count))) 332 | assert pad_up_to_batches * batch_size * host_count >= num_examples 333 | cardinalities = [] 334 | for host_id in range(host_count): 335 | split = deterministic_data.get_read_instruction_for_host( 336 | split="train", 337 | num_examples=num_examples, 338 | host_id=host_id, 339 | host_count=host_count, 340 | drop_remainder=False) 341 | ds = deterministic_data.create_dataset( 342 | builder, 343 | split=split, 344 | batch_dims=[batch_size], 345 | shuffle=False, 346 | num_epochs=1, 347 | pad_up_to_batches=pad_up_to_batches) 348 | cardinalities.append(ds.cardinality().numpy().item()) 349 | self.assertLen(set(cardinalities), 1) 350 | 351 | 352 | if __name__ == "__main__": 353 | tf.test.main() 354 | -------------------------------------------------------------------------------- /clu/preprocess_spec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The CLU Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for parsing a preprocessing spec. 16 | 17 | A preprocessing spec is a list of preprocessing ops separated by '|' that can be 18 | applied sequentially as a preprocessing function. The preprocessing ops are 19 | provided as input and must implement the PreprocessOp protocol. While not 20 | strictly required we also recommend annotating preprocess ops as dataclasses. 21 | 22 | By convention the preprocessing function operates on dictionaries of features. 
23 | Each op can change the dictionary by modifying, adding or removing dictionary 24 | entries. Dictionary entries should be tensors, keys should be strings. 25 | (For common data types we recommend using the feature keys used in TFDS.) 26 | 27 | Example spec: 'fn1|fn2(3)|fn3(keyword=5)' 28 | This will construct the following preprocessing function: 29 | def preprocess_fn(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 30 | features = fn1(features) 31 | features = fn2(features, 3) 32 | features = fn3(features, keyword=5) 33 | return features 34 | 35 | See preprocess_spec_test.py for some simple examples. 36 | """ 37 | 38 | import abc 39 | import ast 40 | import dataclasses 41 | import inspect 42 | import re 43 | import sys 44 | from typing import Dict, List, Sequence, Tuple, Type, TypeVar, Union 45 | 46 | from absl import logging 47 | from flax import traverse_util 48 | import jax.numpy as jnp 49 | import tensorflow as tf 50 | import typing_extensions 51 | from typing_extensions import Protocol 52 | 53 | 54 | # Feature dictionary. Arbitrary nested dictionary with string keys and 55 | # tf.Tensor as leaves. 56 | Tensor = Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor] 57 | # TFDS allows for nested `Features` ... 58 | Features = Dict[str, Union[Tensor, "Features"]] 59 | # ... but it's usually a better idea NOT to nest them. Also better for PyType. 60 | FlatFeatures = Dict[str, Tensor] 61 | D = TypeVar("D", FlatFeatures, tf.data.Dataset) 62 | 63 | # Feature name for the random seed for tf.random.stateless_* ops. By 64 | # convention ops should split of their random seed and keep the SEED_KEY 65 | # feature: 66 | # ``` 67 | # features[SEEQ_KEY], seed = tf.unstack( 68 | # tf.random.experimental.stateless_split(features[SEED_KEY])) 69 | # ```` 70 | SEED_KEY = "_seed" 71 | 72 | # Regex that finds upper case characters. 73 | _CAMEL_CASE_RGX = re.compile(r"(? Features: 91 | """Applies the preprocessing op to the features.""" 92 | 93 | 94 | # Deprecated. Please use `grain.tensorflow.MapTransform`. 95 | class MapTransform(abc.ABC): 96 | """Base class for transformations of single elements. 97 | 98 | This class implements the PreprocessOp interface and also: 99 | - Limits the features to a flat dictionary (instead of an arbitrary nested 100 | dictionary). 101 | - Provides a convenient implementation of `__call__` that can automatically 102 | apply the single transformation to a single example (`FlatFeatures`) or a 103 | `tf.data.Dataset`. The latter is convenient for SeqIO users migrating to 104 | preprocess_spec.py. For multiple transformations we still recommend users 105 | to use the `PreprocessFn` class. 106 | - Enforces subclasses to be a dataclasses. 107 | """ 108 | 109 | def __new__(cls, *args, **kwargs): 110 | del args, kwargs 111 | # Check that our subclass instance is a dataclass. We cannot do this with 112 | # `__init_subclass__`` because the dataclasses.dataclass decorator wraps 113 | # the intermediate class which is a subclass of MapTransform but not a 114 | # dataclass. 115 | if not dataclasses.is_dataclass(cls): 116 | raise ValueError( 117 | f"Class {cls} is not a dataclass. We strongly recommend annotating " 118 | "transformations with `@dataclasses.dataclass(frozen=True)`.") 119 | return super().__new__(cls) 120 | 121 | def __call__(self, features: D) -> D: 122 | """Applies the transformation to the features or the dataset.""" 123 | logging.warning("clu.preprocess_spec.MapTransform is deprecated. 
Please " 124 | "switch to grain.tensorflow.MapTransform.") 125 | if isinstance(features, tf.data.Dataset): 126 | return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE) 127 | return self._transform(features) 128 | 129 | @abc.abstractmethod 130 | def _transform(self, features: FlatFeatures) -> FlatFeatures: 131 | """Transforms the features.""" 132 | 133 | 134 | # Deprecated. Please use `grain.tensorflow.RandomMapTransform`. 135 | class RandomMapTransform(MapTransform, abc.ABC): 136 | """Base class for random transformations of single elements. 137 | 138 | We require all random transformations to use stateless random operations (e.g. 139 | `tf.random.stateless_uniform()`) and respect the provided random seed. The 140 | user can expect the random seed to be unique for the element. 141 | 142 | If multiple random seeds are required the user can split the seed into N 143 | new seeds: 144 | ``` 145 | seeds = tf.unstack(tf.random.experimental.stateless_split(seed, N)) 146 | ``` 147 | """ 148 | 149 | def __call__(self, features: D) -> D: 150 | logging.warning("clu.preprocess_spec.RandomMapTransform is deprecated. " 151 | "Please switch to grain.tensorflow.RandomMapTransform.") 152 | if isinstance(features, tf.data.Dataset): 153 | return features.map(self, num_parallel_calls=tf.data.AUTOTUNE) 154 | 155 | next_seed, seed = tf.unstack( 156 | tf.random.experimental.stateless_split(features.pop(SEED_KEY))) 157 | features = self._transform(features, seed) 158 | features[SEED_KEY] = next_seed 159 | return features 160 | 161 | @abc.abstractmethod 162 | def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures: # pytype: disable=signature-mismatch # overriding-parameter-count-checks 163 | """Transforms the features only using stateless random ops.""" 164 | 165 | 166 | # Deprecated. Please use `grain.tensorflow.FilterMapTransform`. 167 | class FilterTransform(abc.ABC): 168 | 169 | def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset: 170 | logging.warning("clu.preprocess_spec.FilterTransform is deprecated. Please " 171 | "switch to grain.tensorflow.FilterTransform.") 172 | return dataset.filter(self._predicate) 173 | 174 | @abc.abstractmethod 175 | def _predicate(self, features: FlatFeatures) -> tf.Tensor: 176 | """Returns a True if the element should be kept.""" 177 | 178 | 179 | def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]: 180 | """Helper to return all preprocess ops in a module. 181 | 182 | Modules that define processing ops can simply define: 183 | all_ops = lambda: process_spec.get_all_ops(__name__) 184 | all_ops() will then return a list with all dataclasses implementing the 185 | PreprocessOp protocol. 186 | 187 | Args: 188 | module_name: Name of the module. The module must already be imported. 189 | 190 | Returns: 191 | List of tuples of process ops. The first tuple element is the class name 192 | converted to snake case (MyAwesomeTransform => my_awesome_transform) and 193 | the second element is the class. 
194 |   """
195 |   def is_op(x):
196 |     return (inspect.isclass(x) and dataclasses.is_dataclass(x) and
197 |             issubclass(x, PreprocessOp))
198 | 
199 |   op_name = lambda n: _CAMEL_CASE_RGX.sub("_", n).lower()
200 |   members = inspect.getmembers(sys.modules[module_name])
201 |   return [(op_name(name), op) for name, op in members if is_op(op)]
202 | 
203 | 
204 | def _jax_supported_tf_types():
205 |   types = [
206 |       x for _, x in inspect.getmembers(tf.dtypes)
207 |       if isinstance(x, tf.dtypes.DType) and hasattr(jnp, x.name)
208 |   ]
209 |   # bool is called bool_ in jax and won't be found by the expression above.
210 |   return types + [tf.bool]
211 | 
212 | 
213 | @dataclasses.dataclass
214 | class OnlyJaxTypes:
215 |   """Removes all features whose types are not supported by JAX.
216 | 
217 |   This filters dense tensors by dtype and removes sparse and ragged tensors.
218 |   The latter don't have an equivalent in JAX.
219 | 
220 |   Attributes:
221 |     types: List of allowed types. Defaults to all TF types that have an
222 |       equivalent in jax.numpy.
223 |   """
224 | 
225 |   types: List[tf.dtypes.DType] = dataclasses.field(
226 |       default_factory=_jax_supported_tf_types)
227 | 
228 |   def __call__(self, features: Features) -> Features:
229 |     features = traverse_util.flatten_dict(features)
230 |     for name in list(features):
231 |       dtype = features[name].dtype
232 |       if dtype not in self.types:
233 |         del features[name]
234 |         logging.warning(
235 |             "Removing feature %r because dtype %s is not supported in JAX.",
236 |             name, dtype)
237 |       elif isinstance(features[name], tf.SparseTensor):
238 |         del features[name]
239 |         logging.warning(
240 |             "Removing feature %r because sparse tensors are not "
241 |             "supported in JAX.", name)
242 |       elif isinstance(features[name], tf.RaggedTensor):
243 |         del features[name]
244 |         logging.warning(
245 |             "Removing feature %r because ragged tensors are not supported in "
246 |             "JAX.", name)
247 |     features = traverse_util.unflatten_dict(features)
248 |     return features  # pytype: disable=bad-return-type
249 | 
250 | 
251 | @dataclasses.dataclass
252 | class PreprocessFn:
253 |   """Chain of preprocessing ops combined into a single preprocessing function.
254 | 
255 |   Attributes:
256 |     ops: List of feature transformations. Transformations will be applied in
257 |       the given order.
258 |     only_jax_types: If True, will add the `OnlyJaxTypes` transformation at the
259 |       end.
260 | """ 261 | 262 | ops: Sequence[PreprocessOp] 263 | only_jax_types: bool 264 | 265 | def __call__(self, features: Features) -> Features: 266 | """Sequentially applies all `self.ops` and returns the result.""" 267 | logging.info("Features before preprocessing: %s", 268 | _describe_features(features)) 269 | features = features.copy() 270 | for op in self.ops: 271 | features = op(features) 272 | logging.info("Features after op %s:\n%s", op, 273 | _describe_features(features)) 274 | logging.info("Features after preprocessing: %s", 275 | _describe_features(features)) 276 | if self.only_jax_types: 277 | features = OnlyJaxTypes()(features) 278 | return features 279 | 280 | def __add__(self, other: "PreprocessFn") -> "PreprocessFn": 281 | """Concatenates two `PreprocessingFn`.""" 282 | if not isinstance(other, PreprocessFn): 283 | raise ValueError("Can only add other instances of `PreprocessFn`.") 284 | return PreprocessFn( 285 | ops=tuple(self.ops) + tuple(other.ops), 286 | only_jax_types=self.only_jax_types or other.only_jax_types, 287 | ) 288 | 289 | def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn": 290 | """Returns a `PreprocessFn` of the sliced ops.""" 291 | return PreprocessFn( 292 | ops=self.ops[op_index] 293 | if isinstance(op_index, slice) else [self.ops[op_index]], 294 | only_jax_types=self.only_jax_types, 295 | ) 296 | 297 | 298 | def _get_op_class( 299 | expr: List[ast.stmt], 300 | available_ops: Dict[str, Type[PreprocessOp]]) -> Type[PreprocessOp]: 301 | """Gets the process op fn from the given expression.""" 302 | if isinstance(expr, ast.Call): 303 | fn_name = expr.func.id 304 | elif isinstance(expr, ast.Name): 305 | fn_name = expr.id 306 | else: 307 | raise ValueError( 308 | f"Could not parse function name from expression: {expr!r}.") 309 | if fn_name in available_ops: 310 | return available_ops[fn_name] 311 | raise ValueError( 312 | f"'{fn_name}' is not available (available ops: {list(available_ops)}).") 313 | 314 | 315 | def _parse_single_preprocess_op( 316 | spec: str, available_ops: Dict[str, Type[PreprocessOp]]) -> PreprocessOp: 317 | """Parsing the spec for a single preprocess op. 318 | 319 | The op can just be the method name or the method name followed by any 320 | arguments (both positional and keyword) to the method. 321 | See the test cases for some valid examples. 322 | 323 | Args: 324 | spec: String specifying a single processing operations. 325 | available_ops: Available preprocessing ops. 326 | 327 | Returns: 328 | The ProcessOp corresponding to the spec. 329 | """ 330 | try: 331 | expr = ast.parse(spec, mode="eval").body # pytype: disable=attribute-error 332 | except SyntaxError as e: 333 | raise ValueError(f"{spec!r} is not a valid preprocess op spec.") from e 334 | op_class = _get_op_class(expr, available_ops) # pytype: disable=wrong-arg-types 335 | 336 | # Simple case without arguments. 337 | if isinstance(expr, ast.Name): 338 | return op_class() 339 | 340 | assert isinstance(expr, ast.Call) 341 | args = [ast.literal_eval(arg) for arg in expr.args] 342 | kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords} 343 | if not args: 344 | return op_class(**kwargs) 345 | 346 | # Translate positional arguments into keyword arguments. 
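  # E.g. for the spec "fn2(3)" from the module docstring, the value 3 is
  # assigned to the name of fn2's first dataclass field; giving the same field
  # both positionally and as a keyword raises a ValueError below.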
347 | available_arg_names = [f.name for f in dataclasses.fields(op_class)] 348 | for i, arg in enumerate(args): 349 | name = available_arg_names[i] 350 | if name in kwargs: 351 | raise ValueError( 352 | f"Argument {name} to {op_class} given both as positional argument " 353 | f"(value: {arg}) and keyword argument (value: {kwargs[name]}).") 354 | kwargs[name] = arg 355 | 356 | return op_class(**kwargs) 357 | 358 | 359 | def parse(spec: str, 360 | available_ops: List[Tuple[str, Type[PreprocessOp]]], 361 | *, 362 | only_jax_types: bool = True) -> PreprocessFn: 363 | """Parses a preprocess spec; a '|' separated list of preprocess ops.""" 364 | available_ops = dict(available_ops) 365 | if not spec.strip(): 366 | ops = [] 367 | else: 368 | ops = [ 369 | _parse_single_preprocess_op(s, available_ops) for s in spec.split("|") 370 | ] 371 | return PreprocessFn(ops, only_jax_types=only_jax_types) 372 | 373 | 374 | def _describe_features(features: Features) -> str: 375 | description = {} 376 | for k, v in features.items(): 377 | if isinstance(v, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)): 378 | description[k] = f"{v.dtype.name}{v.shape}" 379 | elif isinstance(v, dict): 380 | description[k] = _describe_features(v) 381 | else: 382 | description[k] = f"Unsupported type {type(v)} at feature '{k}'." 383 | return str(description) 384 | --------------------------------------------------------------------------------
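The block below is a minimal usage sketch for `clu/preprocess_spec.py` above. `Rescale`, its `factor` field, the `"image"` feature key, and the spec string `"rescale(2.0)"` are illustrative names invented for this example; only `parse`, `PreprocessFn`, `get_all_ops`, and the '|' spec syntax come from the module itself.

```python
import dataclasses

import tensorflow as tf

from clu import preprocess_spec


@dataclasses.dataclass(frozen=True)
class Rescale:
  """Illustrative op: multiplies the "image" feature by `factor`."""

  factor: float = 1.0

  def __call__(self, features):
    features["image"] = features["image"] * self.factor
    return features


# parse() expects (snake_case_name, op_class) pairs; for a module full of ops,
# get_all_ops(__name__) builds this list automatically.
available_ops = [("rescale", Rescale)]

# Arguments inside the spec are parsed with ast.literal_eval, so 2.0 ends up
# as the `factor` field of a Rescale instance.
preprocess_fn = preprocess_spec.parse("rescale(2.0)", available_ops)

batch = preprocess_fn({"image": tf.ones([2, 2])})
print(batch["image"])  # 2.0 everywhere.
```

With `only_jax_types` left at its default, any feature whose dtype has no `jax.numpy` equivalent would additionally be dropped by `OnlyJaxTypes` before the dictionary is returned.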