├── .bazelrc ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE.bazel ├── docker ├── Dockerfile └── build.sh ├── docs └── images │ ├── envlogger.png │ └── timings.png ├── envlogger ├── BUILD.bazel ├── __init__.py ├── backends │ ├── BUILD.bazel │ ├── __init__.py │ ├── backend_reader.py │ ├── backend_type.py │ ├── backend_writer.py │ ├── cc │ │ ├── BUILD.bazel │ │ ├── __init__.py │ │ ├── episode_info.h │ │ ├── riegeli_dataset_io_constants.h │ │ ├── riegeli_dataset_io_test.cc │ │ ├── riegeli_dataset_reader.cc │ │ ├── riegeli_dataset_reader.h │ │ ├── riegeli_dataset_reader_test.cc │ │ ├── riegeli_dataset_writer.cc │ │ ├── riegeli_dataset_writer.h │ │ ├── riegeli_dataset_writer_test.cc │ │ ├── riegeli_shard_io_test.cc │ │ ├── riegeli_shard_reader.cc │ │ ├── riegeli_shard_reader.h │ │ ├── riegeli_shard_reader_test.cc │ │ ├── riegeli_shard_writer.cc │ │ ├── riegeli_shard_writer.h │ │ └── riegeli_shard_writer_test.cc │ ├── cross_language_test │ │ ├── BUILD.bazel │ │ ├── cc_reader.cc │ │ ├── cross_language_test.py │ │ └── py_writer.py │ ├── in_memory_backend.py │ ├── in_memory_backend_test.py │ ├── python │ │ ├── BUILD.bazel │ │ ├── __init__.py │ │ ├── episode_info.cc │ │ ├── episode_info_test.py │ │ ├── riegeli_dataset_reader.cc │ │ ├── riegeli_dataset_test.py │ │ └── riegeli_dataset_writer.cc │ ├── riegeli_backend_reader.py │ ├── riegeli_backend_writer.py │ ├── riegeli_backend_writer_test.py │ ├── rlds_utils.py │ ├── rlds_utils_test.py │ ├── schedulers.py │ ├── schedulers_test.py │ ├── tfds_backend_testlib.py │ ├── tfds_backend_writer.py │ └── tfds_backend_writer_test.py ├── converters │ ├── BUILD.bazel │ ├── __init__.py │ ├── codec.py │ ├── codec_test.py │ ├── make_visitor.h │ ├── make_visitor_test.cc │ ├── spec_codec.py │ ├── spec_codec_test.py │ ├── xtensor_codec.cc │ ├── xtensor_codec.h │ └── xtensor_codec_test.cc ├── environment_logger.py ├── environment_logger_test.py ├── environment_wrapper.py ├── examples │ ├── BUILD.bazel │ ├── __init__.py │ ├── 
random_agent_catch.py │ └── tfds_random_agent_catch.py ├── platform │ ├── BUILD.bazel │ ├── bundle.h │ ├── default │ │ ├── BUILD.bazel │ │ ├── bundle.cc │ │ ├── bundle.h │ │ ├── filesystem.cc │ │ ├── filesystem.h │ │ ├── parse_text_proto.h │ │ ├── proto_testutil.h │ │ ├── riegeli_file_reader.h │ │ ├── riegeli_file_writer.h │ │ ├── source_location.h │ │ ├── status_builder.cc │ │ ├── status_builder.h │ │ └── status_macros.h │ ├── filesystem.h │ ├── parse_text_proto.h │ ├── proto_testutil.h │ ├── riegeli_file_reader.h │ ├── riegeli_file_writer.h │ ├── status_macros.h │ └── test_macros.h ├── proto │ ├── BUILD.bazel │ ├── __init__.py │ └── storage.proto ├── reader.py ├── requirements.txt ├── setup.py ├── step_data.py └── testing │ ├── BUILD.bazel │ ├── __init__.py │ └── catch_env.py └── patches ├── BUILD.bazel ├── crc32.BUILD.bazel ├── gmp.BUILD.bazel ├── highwayhash.BUILD.bazel ├── net_zstd.BUILD.bazel ├── proto_utils.cc.diff ├── riegeli.diff ├── snappy.BUILD.bazel ├── xtensor.BUILD.bazel ├── xtl.BUILD.bazel └── zlib.BUILD.bazel /.bazelrc: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # The `-std=c++17` option is required to compile the Riegeli library and 16 | # `-D_GLIBCXX_USE_CXX11_ABI=0` option is for compatibility with the shared 17 | # libraries compiled with the same setting. 
18 | 19 | build -c opt 20 | build --cxxopt="-mavx" 21 | build --cxxopt="-std=c++17" 22 | build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | # ARG variable for Python 3 minor version. 4 | # This can be set with `--build-arg PY3_VERSION=10` for Python 3.10. 5 | ARG PY3_VERSION=10 6 | 7 | # Set up timezone to avoid getting stuck at `tzdata` setup.
8 | ENV TZ=America/Montreal 9 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 10 | RUN apt update 11 | RUN apt install -y tzdata 12 | 13 | # Install `software-properties-common` to obtain add-apt-repository. 14 | RUN apt update && apt install -y software-properties-common 15 | 16 | # Add the deadsnakes PPA, which provides Python 3.$PY3_VERSION. 17 | RUN add-apt-repository ppa:deadsnakes/ppa 18 | 19 | # Install necessary packages. 20 | RUN apt-get update && apt-get install -y git curl wget software-properties-common python3.$PY3_VERSION python3.$PY3_VERSION-dev libgmp-dev gcc-9 g++-9 tmux vim 21 | 22 | # Install distutils if not Python 3.10 to get `distutils.util`. 23 | RUN if [ "$PY3_VERSION" != "10" ]; then apt install -y python3-distutils python3-apt; fi 24 | 25 | # Install pip. 26 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.$PY3_VERSION 27 | 28 | # Download bazel. 29 | RUN wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-linux-x86_64 30 | RUN chmod +x /bazel-6.2.1-linux-x86_64 31 | RUN mv /bazel-6.2.1-linux-x86_64 /usr/bin/bazel 32 | 33 | # Add python alternatives. 34 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.$PY3_VERSION 1 35 | 36 | # Override gcc/g++. 37 | RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 60 --slave /usr/bin/g++ g++ /usr/bin/g++-9 38 | 39 | # Non-interactively select choice 1 for the `python3` alternative. 40 | RUN echo 1 | update-alternatives --config python3 41 | # Get latest `pip`. 42 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3 43 | RUN python3 -m pip install --no-cache --upgrade pip setuptools 44 | 45 | ADD /envlogger/requirements.txt /tmp/requirements.txt 46 | RUN python3 -m pip install --no-cache -r /tmp/requirements.txt grpcio-tools 47 | 48 | 49 | # Add `python` so that `/usr/bin/env` finds it. This is used by `bazel`. 50 | RUN ln -s /usr/bin/python3 /usr/bin/python 51 | 52 | ADD .
/envlogger/ 53 | WORKDIR /envlogger 54 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 18 | cd "${DIR}/.." 19 | 20 | # Default to Python 3.11. 21 | PY3_VERSION=${1:-11} 22 | echo "Python version 3.${PY3_VERSION}" 23 | 24 | # Default image label to "envlogger" 25 | IMAGE_LABEL=${2:-envlogger} 26 | echo "Output docker image label: ${IMAGE_LABEL}" 27 | 28 | docker build -t "${IMAGE_LABEL}" -f docker/Dockerfile .
--build-arg PY3_VERSION="${PY3_VERSION}" 29 | -------------------------------------------------------------------------------- /docs/images/envlogger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/envlogger/db88178769bf7e6a01e2ec52ff2b816db39adb2e/docs/images/envlogger.png -------------------------------------------------------------------------------- /docs/images/timings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/envlogger/db88178769bf7e6a01e2ec52ff2b816db39adb2e/docs/images/timings.png -------------------------------------------------------------------------------- /envlogger/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Build targets for environment logger.
16 | load("@rules_python//python:defs.bzl", "py_library", "py_test") 17 | 18 | package(default_visibility = ["//visibility:public"]) 19 | 20 | py_library( 21 | name = "envlogger", 22 | srcs = ["__init__.py"], 23 | deps = [ 24 | ":environment_logger", 25 | ":reader", 26 | ":step_data", 27 | "//envlogger/backends:backend_type", 28 | "//envlogger/backends:riegeli_backend_writer", 29 | "//envlogger/backends:schedulers", 30 | "//envlogger/proto:storage_py_pb2", 31 | ], 32 | ) 33 | 34 | py_library( 35 | name = "environment_logger", 36 | srcs = ["environment_logger.py"], 37 | deps = [ 38 | ":environment_wrapper", 39 | ":step_data", 40 | "//envlogger/backends:backend_type", 41 | "//envlogger/backends:backend_writer", 42 | "//envlogger/backends:in_memory_backend", 43 | "//envlogger/backends:riegeli_backend_writer", 44 | "//envlogger/converters:spec_codec", 45 | ], 46 | ) 47 | 48 | py_test( 49 | name = "environment_logger_test", 50 | srcs = ["environment_logger_test.py"], 51 | deps = [ 52 | ":environment_logger", 53 | ":reader", 54 | ":step_data", 55 | "//envlogger/backends:backend_writer", 56 | "//envlogger/backends:in_memory_backend", 57 | "//envlogger/backends:schedulers", 58 | "//envlogger/converters:codec", 59 | "//envlogger/converters:spec_codec", 60 | "//envlogger/proto:storage_py_pb2", 61 | "//envlogger/testing:catch_env", 62 | "@com_google_riegeli//python/riegeli", 63 | ], 64 | ) 65 | 66 | py_library( 67 | name = "environment_wrapper", 68 | srcs = ["environment_wrapper.py"], 69 | deps = [ 70 | ], 71 | ) 72 | 73 | py_library( 74 | name = "reader", 75 | srcs = ["reader.py"], 76 | data = [ 77 | "//envlogger/backends/python:episode_info.so", 78 | "//envlogger/backends/python:riegeli_dataset_reader.so", 79 | ], 80 | deps = [ 81 | ":step_data", 82 | "//envlogger/backends:backend_reader", 83 | "//envlogger/backends:backend_type", 84 | "//envlogger/backends:in_memory_backend", 85 | "//envlogger/backends:riegeli_backend_reader", 86 | "//envlogger/converters:spec_codec", 87 | ], 88 | )
89 | 90 | py_library( 91 | name = "step_data", 92 | srcs = ["step_data.py"], 93 | deps = [ 94 | ], 95 | ) 96 | -------------------------------------------------------------------------------- /envlogger/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A one-stop import for commonly used modules in EnvLogger.""" 17 | 18 | from envlogger import environment_logger 19 | from envlogger import reader 20 | from envlogger import step_data 21 | from envlogger.backends import backend_type 22 | from envlogger.backends import riegeli_backend_writer 23 | from envlogger.backends import schedulers 24 | from envlogger.proto import storage_pb2 25 | 26 | EnvLogger = environment_logger.EnvLogger 27 | Reader = reader.Reader 28 | BackendType = backend_type.BackendType 29 | StepData = step_data.StepData 30 | Scheduler = schedulers.Scheduler 31 | RiegeliBackendWriter = riegeli_backend_writer.RiegeliBackendWriter 32 | Data = storage_pb2.Data 33 | Datum = storage_pb2.Datum 34 | -------------------------------------------------------------------------------- /envlogger/backends/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Backends. 16 | load("@rules_python//python:defs.bzl", "py_library", "py_test") 17 | 18 | package(default_visibility = ["//visibility:public"]) 19 | 20 | py_library( 21 | name = "backend_type", 22 | srcs = ["backend_type.py"], 23 | ) 24 | 25 | py_library( 26 | name = "schedulers", 27 | srcs = ["schedulers.py"], 28 | deps = [ 29 | "//envlogger:step_data", 30 | ], 31 | ) 32 | 33 | py_test( 34 | name = "schedulers_test", 35 | srcs = ["schedulers_test.py"], 36 | deps = [ 37 | ":schedulers", 38 | "//envlogger:step_data", 39 | ], 40 | ) 41 | 42 | py_library( 43 | name = "backend_reader", 44 | srcs = ["backend_reader.py"], 45 | data = ["//envlogger/backends/python:episode_info.so"], 46 | deps = [ 47 | "//envlogger:step_data", 48 | "//envlogger/converters:codec", 49 | ], 50 | ) 51 | 52 | py_library( 53 | name = "backend_writer", 54 | srcs = ["backend_writer.py"], 55 | deps = [ 56 | "//envlogger:step_data", 57 | ], 58 | ) 59 | 60 | py_library( 61 | name = "in_memory_backend", 62 | srcs = ["in_memory_backend.py"], 63 | data = ["//envlogger/backends/python:episode_info.so"], 64 | deps = [ 65 | ":backend_reader", 66 | ":backend_writer", 67 | "//envlogger:step_data", 68 | ], 69 | ) 70 | 71 | py_test( 72 | name = "in_memory_backend_test", 73 | srcs = ["in_memory_backend_test.py"], 74 | deps = [ 75 | ":backend_writer", 76 | ":in_memory_backend", 77 | ":schedulers", 78 | 
"//envlogger:step_data", 79 | "//envlogger/testing:catch_env", 80 | ], 81 | ) 82 | 83 | py_library( 84 | name = "riegeli_backend_reader", 85 | srcs = ["riegeli_backend_reader.py"], 86 | data = [ 87 | "//envlogger/backends/python:episode_info.so", 88 | "//envlogger/backends/python:riegeli_dataset_reader.so", 89 | ], 90 | deps = [ 91 | ":backend_reader", 92 | "//envlogger:step_data", 93 | "//envlogger/converters:codec", 94 | "//envlogger/proto:storage_py_pb2", 95 | ], 96 | ) 97 | 98 | py_library( 99 | name = "riegeli_backend_writer", 100 | srcs = ["riegeli_backend_writer.py"], 101 | data = [ 102 | "//envlogger/backends/python:riegeli_dataset_writer.so", 103 | ], 104 | deps = [ 105 | ":backend_writer", 106 | ":schedulers", 107 | "//envlogger:step_data", 108 | "//envlogger/converters:codec", 109 | ], 110 | ) 111 | 112 | py_test( 113 | name = "riegeli_backend_writer_test", 114 | srcs = ["riegeli_backend_writer_test.py"], 115 | deps = [ 116 | ":backend_writer", 117 | ":riegeli_backend_reader", 118 | ":riegeli_backend_writer", 119 | ":schedulers", 120 | "//envlogger:step_data", 121 | "//envlogger/testing:catch_env", 122 | ], 123 | ) 124 | 125 | py_library( 126 | name = "rlds_utils", 127 | srcs = ["rlds_utils.py"], 128 | deps = [ 129 | "//envlogger:step_data", 130 | ], 131 | ) 132 | 133 | py_library( 134 | name = "tfds_backend_writer", 135 | srcs = ["tfds_backend_writer.py"], 136 | deps = [ 137 | ":backend_writer", 138 | ":rlds_utils", 139 | "//envlogger:step_data", 140 | ], 141 | ) 142 | 143 | py_library( 144 | name = "tfds_backend_testlib", 145 | srcs = ["tfds_backend_testlib.py"], 146 | deps = [ 147 | ":backend_writer", 148 | ":schedulers", 149 | ":tfds_backend_writer", 150 | "//envlogger:step_data", 151 | "//envlogger/testing:catch_env", 152 | ], 153 | ) 154 | 155 | py_test( 156 | name = "rlds_utils_test", 157 | srcs = ["rlds_utils_test.py"], 158 | deps = [ 159 | ":rlds_utils", 160 | ":tfds_backend_testlib", 161 | "//envlogger:step_data", 162 | ], 163 | ) 164 | 165 | 
py_test( 166 | name = "tfds_backend_writer_test", 167 | srcs = ["tfds_backend_writer_test.py"], 168 | deps = [ 169 | ":backend_writer", 170 | ":rlds_utils", 171 | ":tfds_backend_testlib", 172 | ":tfds_backend_writer", 173 | "//envlogger:step_data", 174 | ], 175 | ) 176 | -------------------------------------------------------------------------------- /envlogger/backends/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /envlogger/backends/backend_reader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Abstract interface for reading trajectories.""" 17 | 18 | import abc 19 | from collections.abc import Callable, Iterator, Sequence 20 | from typing import Any, Generic, Optional, TypeVar, Union 21 | 22 | from absl import logging 23 | from envlogger import step_data 24 | from envlogger.backends.python import episode_info 25 | from envlogger.converters import codec 26 | 27 | 28 | T = TypeVar('T') 29 | 30 | 31 | class _SequenceAdapter(Generic[T], Sequence[T]): 32 | """Convenient visitor for episodes/steps.""" 33 | 34 | def __init__(self, count: int, get_nth_item: Callable[[int], T]): 35 | """Constructor. 36 | 37 | Args: 38 | count: Total number of items. 39 | get_nth_item: Function to get the nth item. 40 | """ 41 | self._count = count 42 | self._index = 0 43 | self._get_nth_item = get_nth_item 44 | 45 | def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]: 46 | """Retrieves items from this sequence. 47 | 48 | Args: 49 | index: item index or slice of indices. 50 | 51 | Returns: 52 | The item at `index` if index is of type `int`, or a list of items if 53 | `index` is a slice. If `index` is a negative integer, then it is 54 | equivalent to index + len(self). 55 | 56 | Raises: 57 | IndexError: if index is an integer outside of the bounds [-length, 58 | length - 1]. 
59 | """ 60 | if isinstance(index, slice): 61 | indices = index.indices(len(self)) 62 | return [self._get_nth_item(i) for i in range(*indices)] 63 | if index >= self._count or index < -self._count: 64 | raise IndexError(f'`index`=={index} is out of the range [{-self._count}, ' 65 | f'{self._count - 1}].') 66 | index = index if index >= 0 else index + self._count 67 | return self._get_nth_item(index) 68 | 69 | def __len__(self) -> int: 70 | return self._count 71 | 72 | def __iter__(self) -> Iterator[T]: 73 | while self._index < len(self): 74 | yield self[self._index] 75 | self._index += 1 76 | self._index = 0 77 | 78 | def __next__(self) -> T: 79 | if self._index < len(self): 80 | index = self._index 81 | self._index += 1 82 | return self[index] 83 | else: 84 | raise StopIteration() 85 | 86 | 87 | class BackendReader(metaclass=abc.ABCMeta): 88 | """Base class for trajectory readers.""" 89 | 90 | def __init__(self): 91 | self._init_visitors() 92 | 93 | def copy(self) -> 'BackendReader': 94 | """Returns a copy of self.""" 95 | 96 | c = self._copy() 97 | c._init_visitors() 98 | return c 99 | 100 | @abc.abstractmethod 101 | def _copy(self) -> 'BackendReader': 102 | """Implementation-specific copy behavior.""" 103 | 104 | def _init_visitors(self): 105 | """Initializes visitors.""" 106 | 107 | logging.info('Creating visitors.') 108 | self._steps = _SequenceAdapter( 109 | count=self._get_num_steps(), get_nth_item=self._get_nth_step) 110 | self._episodes = _SequenceAdapter( 111 | count=self._get_num_episodes(), get_nth_item=self._get_nth_episode) 112 | self._episode_metadata = _SequenceAdapter( 113 | count=self._get_num_episodes(), 114 | get_nth_item=self._get_nth_episode_metadata) 115 | logging.info('Done creating visitors.') 116 | 117 | @abc.abstractmethod 118 | def _get_nth_step(self, i: int) -> step_data.StepData: 119 | pass 120 | 121 | @abc.abstractmethod 122 | def _get_num_steps(self) -> int: 123 | pass 124 | 125 | @abc.abstractmethod 126 | def _get_num_episodes(self) 
-> int: 127 | pass 128 | 129 | @abc.abstractmethod 130 | def _get_nth_episode_info(self, 131 | i: int, 132 | include_metadata: bool = False 133 | ) -> episode_info.EpisodeInfo: 134 | pass 135 | 136 | def _get_nth_episode(self, i: int) -> Sequence[step_data.StepData]: 137 | """Yields timesteps for episode `i` (0-based).""" 138 | episode = self._get_nth_episode_info(i, include_metadata=False) 139 | 140 | def get_nth_step_from_episode(j: int): 141 | return self._get_nth_step(episode.start + j) 142 | 143 | return _SequenceAdapter( 144 | count=episode.num_steps, get_nth_item=get_nth_step_from_episode) 145 | 146 | def _get_nth_episode_metadata(self, i: int) -> Optional[Any]: 147 | """Returns the metadata for episode `i` (0-based).""" 148 | episode = self._get_nth_episode_info(i, include_metadata=True) 149 | return codec.decode(episode.metadata) 150 | 151 | def __enter__(self): 152 | return self 153 | 154 | def __exit__(self, exc_type, exc_value, tb): 155 | self.close() 156 | 157 | def __del__(self): 158 | self.close() 159 | 160 | @abc.abstractmethod 161 | def close(self) -> None: 162 | pass 163 | 164 | @abc.abstractmethod 165 | def metadata(self) -> dict[str, Any]: 166 | pass 167 | 168 | @property 169 | def episodes(self) -> Sequence[Sequence[step_data.StepData]]: 170 | return self._episodes 171 | 172 | def episode_metadata(self) -> Sequence[Optional[Any]]: 173 | return self._episode_metadata 174 | 175 | @property 176 | def steps(self) -> Sequence[step_data.StepData]: 177 | return self._steps 178 | -------------------------------------------------------------------------------- /envlogger/backends/backend_type.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """An enumeration that specifies the logging backend.""" 17 | 18 | import enum 19 | 20 | 21 | class BackendType(enum.IntEnum): 22 | RIEGELI = 0 23 | IN_MEMORY = 1 24 | -------------------------------------------------------------------------------- /envlogger/backends/backend_writer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Abstract trajectory logging interface.""" 17 | 18 | import abc 19 | from typing import Any, Optional 20 | 21 | from envlogger import step_data 22 | from envlogger.backends import schedulers 23 | 24 | 25 | class BackendWriter(metaclass=abc.ABCMeta): 26 | """Abstract trajectory logging interface.""" 27 | 28 | def __init__( 29 | self, 30 | metadata: Optional[dict[str, Any]] = None, 31 | scheduler: Optional[schedulers.Scheduler] = None, 32 | ): 33 | """BackendWriter base class. 
34 | 35 | Args: 36 | metadata: Any dataset-level custom data to be written. 37 | scheduler: A callable that takes the current timestep, current 38 | action, the environment itself and returns True if the current step 39 | should be logged, False otherwise. This function is called _before_ 40 | `step_fn`, meaning that if it returns False, `step_fn` will not be 41 | called at all. NOTE: This scheduler should NEVER skip the first timestep 42 | in the episode, otherwise `EnvLogger` will not know that such episode 43 | really exists. 44 | """ 45 | self._scheduler = scheduler 46 | self._metadata = metadata 47 | 48 | def record_step(self, data: step_data.StepData, is_new_episode: bool) -> None: 49 | if (self._scheduler is not None and not self._scheduler(data)): 50 | return 51 | self._record_step(data, is_new_episode) 52 | 53 | @abc.abstractmethod 54 | def set_episode_metadata(self, data: Any) -> None: 55 | pass 56 | 57 | @abc.abstractmethod 58 | def _record_step(self, data: step_data.StepData, 59 | is_new_episode: bool) -> None: 60 | pass 61 | 62 | @abc.abstractmethod 63 | def close(self) -> None: 64 | pass 65 | 66 | def __del__(self): 67 | self.close() 68 | 69 | def metadata(self): 70 | return self._metadata 71 | -------------------------------------------------------------------------------- /envlogger/backends/cc/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /envlogger/backends/cc/episode_info.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Struct for efficiently retrieving episodic information. 16 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_EPISODE_INFO_H_ 17 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_EPISODE_INFO_H_ 18 | 19 | #include <cstdint> 20 | #include <optional> 21 | 22 | #include "envlogger/proto/storage.pb.h" 23 | 24 | namespace envlogger { 25 | 26 | struct EpisodeInfo { 27 | // The step index where this episode starts. 28 | // This can be either a local step within a single trajectory file, or a 29 | // global step across many shards. 30 | int64_t start; 31 | // The number of steps in this episode. 32 | int64_t num_steps; 33 | // Optional metadata which is only filled if requested.
34 | std::optional<Data> metadata; 35 | }; 36 | 37 | } // namespace envlogger 38 | 39 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_EPISODE_INFO_H_ 40 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_dataset_io_constants.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_ 17 | 18 | #include "absl/strings/string_view.h" 19 | 20 | namespace envlogger { 21 | // These constants are used internally in our codebase and should not be relied 22 | // upon by clients. 23 | namespace internal { 24 | // The riegeli filename for metadata set at Init() time. 25 | inline constexpr absl::string_view kMetadataFilename = "metadata.riegeli"; 26 | // Steps (timesteps, actions and per-step metadata). 27 | inline constexpr absl::string_view kStepsFilename = "steps.riegeli"; 28 | // Step offsets for faster seeking into kStepsFilename. 29 | inline constexpr absl::string_view kStepOffsetsFilename = 30 | "step_offsets.riegeli"; 31 | // Episodic metadata.
32 | inline constexpr absl::string_view kEpisodeMetadataFilename = 33 | "episode_metadata.riegeli"; 34 | // Episode offsets for faster seeking into kEpisodeMetadataFilename and also 35 | // episode starts. 36 | inline constexpr absl::string_view kEpisodeIndexFilename = 37 | "episode_index.riegeli"; 38 | } // namespace internal 39 | } // namespace envlogger 40 | 41 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_ 42 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_dataset_io_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "gmock/gmock.h" 23 | #include "gtest/gtest.h" 24 | #include "absl/flags/flag.h" 25 | #include "absl/status/status.h" 26 | #include "absl/strings/string_view.h" 27 | #include "absl/time/clock.h" 28 | #include "absl/time/time.h" 29 | #include "envlogger/backends/cc/episode_info.h" 30 | #include "envlogger/backends/cc/riegeli_dataset_reader.h" 31 | #include "envlogger/backends/cc/riegeli_dataset_writer.h" 32 | #include "envlogger/backends/cc/riegeli_shard_writer.h" 33 | #include "envlogger/platform/filesystem.h" 34 | #include "envlogger/platform/parse_text_proto.h" 35 | #include "envlogger/platform/proto_testutil.h" 36 | #include "envlogger/platform/test_macros.h" 37 | #include "envlogger/proto/storage.pb.h" 38 | #include "riegeli/records/record_reader.h" 39 | #include "riegeli/records/record_writer.h" 40 | 41 | namespace envlogger { 42 | namespace { 43 | 44 | using ::testing::Eq; 45 | using ::testing::Not; 46 | using ::testing::Value; 47 | 48 | // A simple matcher to compare the output of RiegeliDatasetReader::Episode(). 49 | MATCHER_P2(EqualsEpisode, start_index, num_steps, "") { 50 | return Value(arg.start, start_index) && Value(arg.num_steps, num_steps); 51 | } 52 | // Checks that the metadata given to RiegeliDatasetWriter::Init() is returned by RiegeliDatasetReader::Metadata(). 53 | TEST(RiegeliDatasetTest, MetadataTest) { 54 | const std::string data_dir = 55 | file::JoinPath(getenv("TEST_TMPDIR"), "metadata"); 56 | const Data metadata = 57 | ParseTextProtoOrDie("datum: { values: { int32_values: 1234 } }"); 58 | const int max_episodes_per_shard = -1; 59 | 60 | ENVLOGGER_EXPECT_OK(file::CreateDir(data_dir)); 61 | { 62 | RiegeliDatasetWriter writer; 63 | ENVLOGGER_EXPECT_OK(writer.Init(data_dir, metadata, max_episodes_per_shard, 64 | "transpose,brotli:6,chunk_size:1M")); 65 | // Write a single step to pass RiegeliDatasetReader::Init()'s strict checks.
66 | Data data; 67 | data.mutable_datum()->mutable_values()->add_float_values(1.234f); 68 | EXPECT_TRUE(writer.AddStep(data, /*is_new_episode=*/true)); 69 | writer.Flush(); 70 | } 71 | 72 | RiegeliDatasetReader reader; 73 | ENVLOGGER_EXPECT_OK(reader.Init(data_dir)); 74 | const auto actual_metadata = reader.Metadata(); 75 | ASSERT_THAT(actual_metadata, Not(Eq(std::nullopt))); // Fatal: the optional is dereferenced below. 76 | EXPECT_THAT(*actual_metadata, EqualsProto(metadata)); 77 | 78 | ENVLOGGER_EXPECT_OK(file::RecursivelyDelete(data_dir)); 79 | } 80 | 81 | } // namespace 82 | } // namespace envlogger 83 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_dataset_reader.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "google/protobuf/message.h" 27 | #include "absl/status/status.h" 28 | #include "absl/status/statusor.h" 29 | #include "absl/strings/string_view.h" 30 | #include "envlogger/backends/cc/episode_info.h" 31 | #include "envlogger/backends/cc/riegeli_shard_reader.h" 32 | #include "envlogger/backends/cc/riegeli_shard_writer.h" 33 | #include "envlogger/proto/storage.pb.h" 34 | 35 | namespace envlogger { 36 | 37 | // A RiegeliDatasetReader contains trajectory information for the entire 38 | // lifetime of a single Reinforcement Learning (RL) actor. Each 39 | // RiegeliDatasetReader represents a one dimensional trajectory composed of 0 or 40 | // more steps. Steps can be grouped into non-overlapping, contiguous episodes 41 | // for episodic RL environments. 42 | // 43 | // Internally, a RiegeliDatasetReader is sharded into chronologically ordered 44 | // sub directories called "timestamp directories". Each of these sub directories 45 | // contains its own trajectory and indexes, so in order to find a specific 46 | // step or episode we first need to determine the sub directory and then its 47 | // internal index. An episode is never split between two shards. 48 | class RiegeliDatasetReader { 49 | public: 50 | RiegeliDatasetReader() = default; 51 | RiegeliDatasetReader(RiegeliDatasetReader&&) = default; 52 | ~RiegeliDatasetReader(); 53 | 54 | // Clones a RiegeliDatasetReader. 55 | // The cloned reader can be safely used in a different thread. 56 | absl::StatusOr Clone(); 57 | 58 | // Releases resources and closes the underlying files. 59 | void Close(); 60 | // Initializes this reader from the dataset stored under `data_dir` (e.g. one written by RiegeliDatasetWriter). 
61 | absl::Status Init(absl::string_view data_dir); 62 | 63 | // Returns the dataset-level metadata (the payload given to RiegeliDatasetWriter::Init()), if any.
64 | std::optional Metadata() const; 65 | // Dataset-wide counts: Step() and Episode() accept indices in [0, NumSteps()) and [0, NumEpisodes()) respectively. 66 | int64_t NumSteps() const; 67 | int64_t NumEpisodes() const; 68 | 69 | // Returns step data at index `step_index`. 70 | // Returns nullopt if `step_index` is not in [0, NumSteps()). 71 | template 72 | std::optional Step(int64_t step_index); 73 | 74 | // Returns information for accessing a specific episode. 75 | // 76 | // NOTE: `start` in the returned object refers to the "global" step index, not 77 | // the local one in the shard. That is, this value can be passed to Step() 78 | // (declared above) without any modification. 79 | // 80 | // Returns nullopt if `episode_index` is not in [0, NumEpisodes()). 81 | std::optional Episode(int64_t episode_index, 82 | bool include_metadata = false); 83 | 84 | // Returns a shard reader for the specified episode index. 85 | absl::StatusOr GetShard(int64_t episode_index); 86 | // Returns the data directory associated with this reader (`data_dir_`). 87 | std::string DataDir() const { return data_dir_; } 88 | 89 | private: 90 | // A Shard represents a single timestamp directory. 91 | struct Shard { 92 | // The path to the timestamp directory. 93 | std::string timestamp_dir; 94 | // The index internal to this timestamp directory. 95 | RiegeliShardReader index; 96 | 97 | // The global step index at which this shard starts. 98 | int64_t global_step_index = -1; 99 | 100 | // The cumulative number of steps up to this shard (in the order that it's 101 | // inserted in shards_). 102 | int64_t cumulative_steps = 0; 103 | // The cumulative number of episodes up to this shard (in the order that 104 | // it's inserted in shards_). 105 | int64_t cumulative_episodes = 0; 106 | }; 107 | 108 | // Maps `global_index` to the shard in `shards_` that contains it, using `extractor` to read each shard's cumulative count (e.g. `cumulative_steps`, as in Step()). 
109 | // Returns that shard together with the index local to it (Step() uses the second element as the local step index). 110 | std::pair FindShard( 111 | int64_t global_index, std::function extractor); 112 | 113 | // The data directory associated with this reader. 114 | std::string data_dir_; 115 | 116 | // The total number of steps across all timestamp directories.
117 | int64_t total_num_steps_ = 0; 118 | // The total number of episodes across all timestamp directories. 119 | // Note: This implementation differs from the Python one because the latter 120 | // depends on an explicit "max_episodes_per_file" entry in the metadata while 121 | // this one calculates the same thing via binary search. 122 | int64_t total_num_episodes_ = 0; 123 | // Dataset-level metadata; presumably the value returned by Metadata(). 124 | std::optional metadata_; 125 | 126 | // The list of all shards in this RiegeliDatasetReader. 127 | std::vector shards_; 128 | }; 129 | 130 | //////////////////////////////////////////////////////////////////////////////// 131 | // Implementation details of RiegeliDatasetReader::Step. 132 | //////////////////////////////////////////////////////////////////////////////// 133 | 134 | template 135 | std::optional RiegeliDatasetReader::Step(int64_t step_index) { 136 | if (step_index < 0 || step_index >= NumSteps()) return std::nullopt; 137 | // Map the global step index to (shard, local step index), then read from that shard. 138 | Shard* shard = nullptr; 139 | int64_t local_step_index = -1; 140 | std::tie(shard, local_step_index) = 141 | FindShard(step_index, [](const RiegeliDatasetReader::Shard& shard) { 142 | return shard.cumulative_steps; 143 | }); 144 | return shard->index.Step(local_step_index); 145 | } 146 | 147 | } // namespace envlogger 148 | 149 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_ 150 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_dataset_writer.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "envlogger/backends/cc/riegeli_dataset_writer.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "glog/logging.h" 25 | #include "google/protobuf/message.h" 26 | #include "absl/status/status.h" 27 | #include "absl/status/statusor.h" 28 | #include "absl/strings/string_view.h" 29 | #include "absl/time/clock.h" 30 | #include "absl/time/time.h" 31 | #include "envlogger/backends/cc/episode_info.h" 32 | #include "envlogger/backends/cc/riegeli_dataset_io_constants.h" 33 | #include "envlogger/backends/cc/riegeli_shard_reader.h" 34 | #include "envlogger/backends/cc/riegeli_shard_writer.h" 35 | #include "envlogger/platform/filesystem.h" 36 | #include "envlogger/platform/riegeli_file_writer.h" 37 | #include "envlogger/platform/status_macros.h" 38 | #include "envlogger/proto/storage.pb.h" 39 | #include "riegeli/base/types.h" 40 | #include "riegeli/records/record_reader.h" 41 | #include "riegeli/records/record_writer.h" 42 | 43 | namespace envlogger { 44 | 45 | namespace { 46 | 47 | // Writes `data` as a single (transposed) record to the Riegeli file at `filepath`, then flushes and closes it; returns the writer's status on the first failure.
48 | absl::Status WriteSingleRiegeliRecord(const absl::string_view filepath, 49 | const Data& data) { 50 | riegeli::RecordWriter writer( 51 | RiegeliFileWriter(filepath), 52 | riegeli::RecordWriterBase::Options().set_transpose(true)); 53 | if (!writer.WriteRecord(data)) return writer.status(); 54 | if (!writer.Flush(riegeli::FlushType::kFromMachine)) return writer.status(); 55 | if (!writer.Close()) return writer.status(); 56 | 57 | return absl::OkStatus(); 58 | } 59 | // Formats `time` in UTC as %Y%m%dT%H%M%S followed by six fractional-second digits; the result names a new shard (timestamp) directory. 60 | std::string NewTimestampDirName(absl::Time time) { 61 | return absl::FormatTime("%Y%m%dT%H%M%S%E6f", time, absl::UTCTimeZone()); 62 | } 63 | // Creates a new timestamp directory under `data_dir` and re-initializes `*writer` so that it writes a fresh shard into that directory. 64 | absl::Status CreateRiegeliShardWriter(absl::string_view data_dir, 65 | absl::string_view writer_options, 66 | RiegeliShardWriter* writer) { 67 | const std::string dirname = NewTimestampDirName(absl::Now()); 68 | const std::string timestamp_dir = file::JoinPath(data_dir, dirname); 69 | ENVLOGGER_RETURN_IF_ERROR(file::CreateDir(timestamp_dir)); 70 | writer->Flush(); // Flush before creating a new one.
71 | ENVLOGGER_RETURN_IF_ERROR(writer->Init( 72 | /*steps_filepath=*/file::JoinPath(timestamp_dir, 73 | internal::kStepsFilename), 74 | /*step_offsets_filepath=*/ 75 | file::JoinPath(timestamp_dir, internal::kStepOffsetsFilename), 76 | /*episode_metadata_filepath=*/ 77 | file::JoinPath(timestamp_dir, internal::kEpisodeMetadataFilename), 78 | /*episode_index_filepath=*/ 79 | file::JoinPath(timestamp_dir, internal::kEpisodeIndexFilename), 80 | writer_options)); 81 | return absl::OkStatus(); 82 | } 83 | 84 | } // namespace 85 | // Writes `metadata` only if `data_dir` has no metadata file yet, and opens the first shard immediately when sharding is disabled (max_episodes_per_shard <= 0). 86 | absl::Status RiegeliDatasetWriter::Init(std::string data_dir, 87 | const Data& metadata, 88 | int64_t max_episodes_per_shard, 89 | std::string writer_options, 90 | int episode_counter) { 91 | if (data_dir.empty()) return absl::NotFoundError("Empty data_dir."); 92 | 93 | data_dir_ = data_dir; 94 | writer_options_ = writer_options; 95 | const std::string metadata_filepath = 96 | file::JoinPath(data_dir, internal::kMetadataFilename); 97 | if (!file::GetSize(metadata_filepath).ok()) { 98 | // Only write metadata if the file is missing (GetSize() failed), so re-initializing over an existing `data_dir` keeps its metadata.
99 | ENVLOGGER_RETURN_IF_ERROR(WriteSingleRiegeliRecord( 100 | /*filepath=*/metadata_filepath, /*data=*/metadata)); 101 | } 102 | max_episodes_per_shard_ = max_episodes_per_shard; 103 | if (max_episodes_per_shard_ <= 0) { 104 | ENVLOGGER_RETURN_IF_ERROR( 105 | CreateRiegeliShardWriter(data_dir, writer_options_, &writer_)); 106 | } 107 | episode_counter_ = episode_counter; 108 | // NOTE(review): with sharding enabled, no shard is opened here; AddStep() opens one only when a new episode arrives at a multiple of max_episodes_per_shard_. If `episode_counter` is not such a multiple (e.g. when resuming), steps added before the next multiple go to a `writer_` that was never Init()-ed. Confirm this is intended. 109 | return absl::OkStatus(); 110 | } 111 | // Opens a new shard on every max_episodes_per_shard_-th new episode (when sharding is enabled), then forwards `data` to the current shard writer. 112 | bool RiegeliDatasetWriter::AddStep(const google::protobuf::Message& data, 113 | bool is_new_episode) { 114 | if (is_new_episode) { 115 | if (max_episodes_per_shard_ > 0 && // NOTE(review): short-circuiting skips the `++` below when sharding is disabled, so EpisodeCounter() does not advance in that mode. 116 | episode_counter_++ % max_episodes_per_shard_ == 0) { 117 | ENVLOGGER_CHECK_OK( 118 | CreateRiegeliShardWriter(data_dir_, writer_options_, &writer_)); 119 | } 120 | } 121 | return writer_.AddStep(data, is_new_episode); 122 | } 123 | // The methods below forward directly to the current shard writer. 
124 | void RiegeliDatasetWriter::SetEpisodeMetadata(const Data& data) { 125 | writer_.SetEpisodeMetadata(data); 126 | } 127 | 128 | void RiegeliDatasetWriter::Flush() { writer_.Flush(); } 129 | 130 | void RiegeliDatasetWriter::Close() { writer_.Close(); } 131 | 132 | } // namespace envlogger 133 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_dataset_writer.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "google/protobuf/message.h" 24 | #include "absl/status/status.h" 25 | #include "absl/strings/string_view.h" 26 | #include "envlogger/backends/cc/episode_info.h" 27 | #include "envlogger/backends/cc/riegeli_shard_reader.h" 28 | #include "envlogger/backends/cc/riegeli_shard_writer.h" 29 | #include "envlogger/proto/storage.pb.h" 30 | 31 | namespace envlogger { 32 | 33 | // Automates creating trajectories that can be efficiently read from disk by 34 | // RiegeliDatasetReader. 35 | class RiegeliDatasetWriter { 36 | public: 37 | RiegeliDatasetWriter() = default; 38 | 39 | // Initializes this writer to write into `data_dir`. 40 | // `metadata` is a client-specific payload, written only if `data_dir` does not already contain a metadata file. 41 | // `max_episodes_per_shard` determines the maximum number of episodes a single 42 | // RiegeliShardWriter shard will hold. If non-positive, a single shard file 43 | // will hold all steps and episodes. 44 | // `episode_counter` sets the current episode that this writer is processing. 45 | // It influences the shards that are created if `max_episodes_per_shard` is 46 | // positive. 47 | // Returns NotFoundError if `data_dir` is empty. 48 | // IMPORTANT: `data_dir` MUST exist _before_ calling Init(). 49 | absl::Status Init( 50 | std::string data_dir, const Data& metadata = Data(), 51 | int64_t max_episodes_per_shard = 0, 52 | std::string writer_options = "transpose,brotli:6,chunk_size:1M", 53 | int episode_counter = 0); 54 | 55 | // Adds `data` to the trajectory and marks it as a new episode if 56 | // `is_new_episode==true`. 57 | // Returns true if `data` has been written, false otherwise.
58 | bool AddStep(const google::protobuf::Message& data, bool is_new_episode = false); 59 | 60 | // Sets episodic metadata for the _current_ episode. 61 | // This can be called multiple times but the value will be written only when a 62 | // new episode comes in or when this writer is about to be destroyed. 63 | // Calling this before the first `AddStep(..., /*is_new_episode=*/true)` 64 | // causes this writer to ignore the `data` that's passed. 65 | void SetEpisodeMetadata(const Data& data); 66 | // Forwards to the current shard writer's Flush(). 67 | void Flush(); 68 | 69 | // Finalizes all writes and releases all handles. 70 | void Close(); 71 | 72 | // Const accessors to internal state. 73 | std::string DataDir() const { return data_dir_; } 74 | std::string WriterOptions() const { return writer_options_; } 75 | int64_t MaxEpisodesPerShard() const { return max_episodes_per_shard_; } 76 | int EpisodeCounter() const { return episode_counter_; } 77 | 78 | private: 79 | std::string data_dir_; 80 | std::string writer_options_; 81 | int64_t max_episodes_per_shard_ = 0; 82 | int episode_counter_ = 0; 83 | RiegeliShardWriter writer_; 84 | }; 85 | 86 | } // namespace envlogger 87 | 88 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_ 89 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_shard_reader.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "glog/logging.h" 25 | #include "google/protobuf/message.h" 26 | #include "absl/status/status.h" 27 | #include "absl/status/statusor.h" 28 | #include "absl/strings/str_cat.h" 29 | #include "absl/strings/string_view.h" 30 | #include "envlogger/backends/cc/episode_info.h" 31 | #include "envlogger/platform/riegeli_file_reader.h" 32 | #include "envlogger/proto/storage.pb.h" 33 | #include "riegeli/base/object.h" 34 | #include "riegeli/records/record_reader.h" 35 | 36 | namespace envlogger { 37 | 38 | // A RiegeliShardReader reads a single shard written by RiegeliShardWriter: a 39 | // steps file (e.g. "steps.riegeli") plus external index files (step offsets, 40 | // episode metadata and episode index) that point into it, so that accesses 41 | // have minimal latency. 42 | // 43 | // This class reads the index files and uses them to return individual steps 44 | // and information for accessing episodes stored in the steps and episode 45 | // metadata files. 46 | // RiegeliShardReader is not thread-safe (i.e. multiple threads should not 47 | // concurrently call `Step()` or `Episode()`). However, it is possible to create 48 | // cheap copies of RiegeliShardReader using the Clone function, creating a new 49 | // reader which will share with the original reader the relatively expensive 50 | // in-memory index but will own its individual Riegeli file handles. 51 | // These copies can then be safely passed to threads. 52 | // 53 | // Note on nomenclature: 54 | // - The term "index" when used in an array refers to the position in the 55 | // array and is always 0-based throughout this library. 
For example, in the
56 | // array [13, 17, 24, 173] the element 13 is at index 0, 17 at index 1, 24 at 57 | // index 2 and 173 at index 3. 58 | // - The term "offset" refers to a file offset in bytes. This is the unit used 59 | // by Riegeli to position its reading head to read a particular record. 60 | class RiegeliShardReader { 61 | public: 62 | RiegeliShardReader() = default; 63 | 64 | RiegeliShardReader(RiegeliShardReader&&) = default; 65 | RiegeliShardReader& operator=(RiegeliShardReader&&) = default; 66 | 67 | ~RiegeliShardReader(); 68 | 69 | // Reads trajectory data written by RiegeliShardWriter. 70 | // 71 | // If `step_offsets_filepath` is empty, Init() will return early 72 | // leaving the object members empty. All methods should fail if called on an 73 | // object in this state. This is used by clients for efficiently initializing 74 | // objects to minimize memory allocations by preallocating everything 75 | // beforehand. 76 | absl::Status Init(absl::string_view steps_filepath, 77 | absl::string_view step_offsets_filepath, 78 | absl::string_view episode_metadata_filepath, 79 | absl::string_view episode_index_filepath); 80 | 81 | // Clones a RiegeliShardReader. 82 | // The cloned reader owns its own file handles but shares the ShardData 83 | // with the original reader. 84 | // The cloned reader can safely be used in a different thread. 85 | absl::StatusOr Clone(); 86 | 87 | // Returns the number of steps indexed by this RiegeliShardReader. 88 | int64_t NumSteps() const; 89 | 90 | // Returns the number of episodes indexed by this RiegeliShardReader. 91 | int64_t NumEpisodes() const; 92 | 93 | // Returns step data at index `step_index`. 94 | // Returns nullopt if step_index is not in [0, NumSteps()), if this reader holds no shard data (e.g. it was default-constructed), or if seeking to or reading the record fails. 95 | template 96 | std::optional Step(int64_t step_index); 97 | 98 | // Returns information for accessing a specific episode. 99 | // Returns nullopt if `episode_index` is not in [0, NumEpisodes()).
100 | std::optional Episode(int64_t episode_index, 101 | bool include_metadata = false); 102 | 103 | // Releases resources and closes the underlying files. 104 | void Close(); 105 | 106 | private: 107 | struct ShardData { 108 | std::string steps_filepath; 109 | std::string episode_metadata_filepath; 110 | 111 | // Riegeli offsets for quickly accessing steps. 112 | std::vector step_offsets; 113 | 114 | // Riegeli offsets for quickly accessing episodic metadata. 115 | std::vector episode_metadata_offsets; 116 | 117 | // The first step of each episode. 118 | // Episode index -> Step index. 119 | std::vector episode_starts; 120 | }; 121 | 122 | // Index data of the shard (file paths and Riegeli offsets). 123 | // Note that all instances created with Clone() will share that information. 124 | std::shared_ptr shard_; 125 | 126 | riegeli::RecordReader steps_reader_{riegeli::kClosed}; 127 | riegeli::RecordReader episode_metadata_reader_{ 128 | riegeli::kClosed}; 129 | }; 130 | 131 | //////////////////////////////////////////////////////////////////////////////// 132 | // Implementation details of RiegeliShardReader::Step.
133 | //////////////////////////////////////////////////////////////////////////////// 134 | 135 | template 136 | std::optional RiegeliShardReader::Step(int64_t step_index) { 137 | // `shard_` is null for a default-constructed reader; fail instead of dereferencing it. 138 | if (shard_ == nullptr) return std::nullopt; 139 | const auto& step_offsets = shard_->step_offsets; 140 | if (step_index < 0 || 141 | step_index >= static_cast(step_offsets.size())) { 142 | return std::nullopt; 143 | } 144 | 145 | const int64_t offset = step_offsets[step_index]; 146 | if (!steps_reader_.Seek(offset)) { 147 | VLOG(0) << absl::StrCat("Failed to seek to offset ", offset, 148 | " status: ", steps_reader_.status().ToString()); 149 | return std::nullopt; 150 | } 151 | T data; 152 | const bool read_status = steps_reader_.ReadRecord(data); 153 | if (!read_status) return std::nullopt; 154 | 155 | return data; 156 | } 157 | 158 | } // namespace envlogger 159 | 160 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_ 161 | -------------------------------------------------------------------------------- /envlogger/backends/cc/riegeli_shard_writer.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "google/protobuf/message.h" 23 | #include "absl/status/status.h" 24 | #include "absl/strings/string_view.h" 25 | #include "envlogger/backends/cc/episode_info.h" 26 | #include "envlogger/platform/riegeli_file_writer.h" 27 | #include "envlogger/proto/storage.pb.h" 28 | #include "riegeli/base/object.h" 29 | #include "riegeli/records/record_writer.h" 30 | 31 | namespace envlogger { 32 | 33 | // Writes a single shard: a steps file plus Riegeli index files that allow for 34 | // efficient record lookup into it (i.e. external indexes). 35 | class RiegeliShardWriter { 36 | public: 37 | RiegeliShardWriter() = default; 38 | ~RiegeliShardWriter(); 39 | 40 | // IMPORTANT: The directory where these files live MUST exist _before_ calling 41 | // Init(). 42 | // 43 | // steps_filepath: 44 | // Path to riegeli file that'll contain the actual trajectory data as Data 45 | // objects, including timesteps (observations, rewards, discounts), actions 46 | // and step metadata. 47 | // Each riegeli entry corresponds to a single step. 48 | // Each entry can be seeked in O(1) time via an external index 49 | // (step_offsets_filepath). 50 | // step_offsets_filepath: 51 | // An external index into steps_filepath. This allows clients to access any 52 | // individual step in O(1). Riegeli entries are Datum objects. 53 | // episode_metadata_filepath: 54 | // Path to riegeli file that'll contain episodic metadata (e.g. discounted 55 | // returns, summarized data etc). Riegeli entries are Data objects. 56 | // Each entry corresponds to the metadata of a single episode. 57 | // Metadata is stored only for episodes that have it (i.e. if an episode 58 | // does not store any metadata, nothing is stored here).
59 | // Each entry can be seeked in O(1) time via an external index 60 | // (episode_index_filepath). 61 | // episode_index_filepath: 62 | // Path to a riegeli file containing (|Episodes| x 2) tensors where: 63 | // * 1st dim is a step index indicating the start of the episode. 64 | // * 2nd dim is a riegeli offset for optional metadata (for looking up into 65 | // episode_metadata_filepath). -1 indicates that no metadata exists for a 66 | // specific episode. 67 | // Each riegeli entry is a Datum object. 68 | absl::Status Init(absl::string_view steps_filepath, 69 | absl::string_view step_offsets_filepath, 70 | absl::string_view episode_metadata_filepath, 71 | absl::string_view episode_index_filepath, 72 | absl::string_view writer_options); 73 | // `writer_options` presumably configures the Riegeli record writers (RiegeliDatasetWriter passes e.g. "transpose,brotli:6,chunk_size:1M"). 74 | // Adds `data` to the trajectory and marks it as a new episode if 75 | // `is_new_episode==true`. 76 | // Returns true if `data` has been written, false otherwise. 77 | bool AddStep(const google::protobuf::Message& data, bool is_new_episode = false); 78 | 79 | // Sets episodic metadata for the _current_ episode. 80 | // This can be called multiple times but the value will be written only when a 81 | // new episode comes in or when this writer is about to be destroyed. 82 | // Calling this before the first `AddStep(..., /*is_new_episode=*/true)` 83 | // causes this writer to ignore the `data` that's passed. 84 | void SetEpisodeMetadata(const Data& data); 85 | 86 | // Flushes the index to disk. 87 | void Flush(); 88 | 89 | // Finalizes all writes and releases all handles. 90 | void Close(); 91 | 92 | private: 93 | // The number of steps at the time of the last flush. 94 | int num_steps_at_flush_ = 0; 95 | // Presumably the Riegeli offset of each step in `steps_writer_` (step index -> offset). 96 | std::vector step_offsets_; 97 | 98 | // The first step of each episode. 99 | // Episode index -> Step index. 
100 | std::vector episode_starts_; 101 | // The riegeli offset into `episode_metadata_writer_`. 102 | std::vector episode_offsets_; 103 | 104 | // Metadata for the _current_ episode.
105 | std::optional episode_metadata_; 106 | 107 | // Writers for the steps file and its index files; they start closed (presumably opened by Init()). 108 | riegeli::RecordWriter steps_writer_{riegeli::kClosed}; 109 | riegeli::RecordWriter step_offsets_writer_{ 110 | riegeli::kClosed}; 111 | riegeli::RecordWriter episode_metadata_writer_{ 112 | riegeli::kClosed}; 113 | riegeli::RecordWriter episode_index_writer_{ 114 | riegeli::kClosed}; 115 | 116 | }; 117 | 118 | } // namespace envlogger 119 | 120 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_ 121 | -------------------------------------------------------------------------------- /envlogger/backends/cross_language_test/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Tests for reading/writing from different languages.
16 | 17 | load("@rules_python//python:defs.bzl", "py_binary", "py_test") 18 | # C++ binary that reads the trajectory directory given by --trajectories_dir and checks its contents. 19 | cc_binary( 20 | name = "cc_reader", 21 | testonly = True, 22 | srcs = ["cc_reader.cc"], 23 | deps = [ 24 | "//envlogger/backends/cc:riegeli_dataset_reader", 25 | "//envlogger/converters:xtensor_codec", 26 | "//envlogger/platform:proto_testutil", 27 | "//envlogger/proto:storage_cc_proto", 28 | "@com_github_google_glog//:glog", 29 | "@com_google_absl//absl/flags:flag", 30 | "@com_google_absl//absl/flags:parse", 31 | "@com_google_googletest//:gtest", 32 | "@gmp", 33 | "@xtensor", 34 | ], 35 | ) 36 | # Python binary built on //envlogger; presumably writes the trajectories that cc_reader checks. 37 | py_binary( 38 | name = "py_writer", 39 | srcs = ["py_writer.py"], 40 | deps = [ 41 | "//envlogger", 42 | ], 43 | ) 44 | # Cross-language test; receives cc_reader and py_writer as data dependencies. 45 | py_test( 46 | name = "cross_language_test", 47 | srcs = ["cross_language_test.py"], 48 | data = [ 49 | ":cc_reader", 50 | ":py_writer", 51 | ], 52 | deps = [ 53 | "@rules_python//python/runfiles", 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /envlogger/backends/cross_language_test/cc_reader.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include "absl/flags/parse.h" 21 | #include "gmock/gmock.h" 22 | #include "gtest/gtest.h" 23 | #include "absl/flags/flag.h" 24 | #include 25 | #include "envlogger/backends/cc/riegeli_dataset_reader.h" 26 | #include "envlogger/converters/xtensor_codec.h" 27 | #include "envlogger/platform/proto_testutil.h" 28 | #include "envlogger/proto/storage.pb.h" 29 | #include "xtensor/xarray.hpp" 30 | 31 | ABSL_FLAG(std::string, trajectories_dir, "", "Path to reader trajectory."); 32 | 33 | using ::testing::DoubleEq; 34 | using ::testing::Eq; 35 | using ::testing::FloatEq; 36 | using ::testing::IsTrue; 37 | using ::testing::SizeIs; 38 | 39 | int main(int argc, char** argv) { 40 | absl::ParseCommandLine(argc, argv); 41 | 42 | VLOG(0) << "Starting C++ Reader..."; 43 | VLOG(0) << "--trajectories_dir: " << absl::GetFlag(FLAGS_trajectories_dir); 44 | envlogger::RiegeliDatasetReader reader; 45 | const absl::Status init_status = 46 | reader.Init(absl::GetFlag(FLAGS_trajectories_dir)); 47 | VLOG(0) << "init_status: " << init_status; 48 | 49 | VLOG(0) << "reader.NumSteps(): " << reader.NumSteps(); 50 | for (int64_t i = 0; i < reader.NumSteps(); ++i) { 51 | std::optional step = reader.Step(i); 52 | EXPECT_THAT(step.has_value(), IsTrue()) 53 | << "All steps should be readable. 
Step " << i << " is not available."; 54 | 55 | envlogger::DataView step_view(std::addressof(*step)); 56 | EXPECT_THAT(step_view->value_case(), Eq(envlogger::Data::ValueCase::kTuple)) 57 | << "Each step should be a tuple."; 58 | EXPECT_THAT(step_view, SizeIs(3)) 59 | << "Each step should consist of (timestep, action, custom data)"; 60 | const envlogger::Data& timestep = *step_view[0]; 61 | const envlogger::Data& action = *step_view[1]; 62 | const envlogger::Data& custom_data = *step_view[2]; 63 | VLOG(1) << "timestep: " << timestep.ShortDebugString(); 64 | VLOG(1) << "action: " << action.ShortDebugString(); 65 | VLOG(1) << "custom_data: " << custom_data.ShortDebugString(); 66 | envlogger::DataView timestep_view(×tep); 67 | EXPECT_THAT(timestep_view->value_case(), 68 | Eq(envlogger::Data::ValueCase::kTuple)) 69 | << "Each timestep should be a tuple."; 70 | EXPECT_THAT(timestep_view, SizeIs(4)) 71 | << "Each timestep should consist of (step type, reward, discount, " 72 | "observation)"; 73 | 74 | // Check timestep values. 75 | // Check step type. 76 | const envlogger::Data& step_type = *timestep_view[0]; 77 | VLOG(2) << "step_type: " << step_type.ShortDebugString(); 78 | std::optional decoded_step_type = 79 | envlogger::Decode(step_type.datum()); 80 | EXPECT_THAT(decoded_step_type.has_value(), IsTrue()) 81 | << "Failed to decode step_type"; 82 | EXPECT_THAT(absl::holds_alternative(*decoded_step_type), 83 | IsTrue()); 84 | const mpz_class step_type_decoded = 85 | absl::get(*decoded_step_type); 86 | EXPECT_THAT(cmp(step_type_decoded, i ? 1 : 0), Eq(0)); 87 | // Check reward. 
88 | const envlogger::Data& reward = *timestep_view[1]; 89 | VLOG(2) << "reward: " << reward.ShortDebugString(); 90 | std::optional decoded_reward = 91 | envlogger::Decode(reward.datum()); 92 | EXPECT_THAT(decoded_reward.has_value(), IsTrue()) 93 | << "Failed to decode reward"; 94 | EXPECT_THAT(absl::holds_alternative(*decoded_reward), IsTrue()); 95 | const double r = absl::get(*decoded_reward); 96 | EXPECT_THAT(r, DoubleEq(i / 100.0)); 97 | // Check discount. 98 | const envlogger::Data& discount = *timestep_view[2]; 99 | VLOG(2) << "discount: " << discount.ShortDebugString(); 100 | std::optional decoded_discount = 101 | envlogger::Decode(discount.datum()); 102 | EXPECT_THAT(decoded_discount.has_value(), IsTrue()) 103 | << "Failed to decode discount"; 104 | EXPECT_THAT(absl::holds_alternative(*decoded_discount), IsTrue()); 105 | const double gamma = absl::get(*decoded_discount); 106 | EXPECT_THAT(gamma, DoubleEq(0.99)); 107 | // Check observation. 108 | const envlogger::Data& observation = *timestep_view[3]; 109 | VLOG(2) << "observation: " << observation.ShortDebugString(); 110 | std::optional decoded_obs = 111 | envlogger::Decode(observation.datum()); 112 | EXPECT_THAT(decoded_obs.has_value(), IsTrue()) 113 | << "Failed to decode observation"; 114 | EXPECT_THAT(absl::holds_alternative>(*decoded_obs), 115 | IsTrue()); 116 | const xt::xarray& obs = absl::get>(*decoded_obs); 117 | EXPECT_THAT(obs, SizeIs(1)); 118 | EXPECT_THAT(obs(0), FloatEq(i)); 119 | 120 | // Check action. 121 | std::optional decoded_action = 122 | envlogger::Decode(action.datum()); 123 | EXPECT_THAT(decoded_action.has_value(), IsTrue()) 124 | << "Failed to decode action"; 125 | EXPECT_THAT(absl::holds_alternative(*decoded_action), IsTrue()); 126 | const int a = absl::get(*decoded_action); 127 | EXPECT_THAT(a, Eq(100 - i)); 128 | 129 | // There should be no custom data, but it should still be a valid pointer. 
130 | EXPECT_THAT(custom_data, EqualsProto(envlogger::Data())); 131 | } 132 | 133 | reader.Close(); 134 | 135 | return 0; 136 | } 137 | -------------------------------------------------------------------------------- /envlogger/backends/cross_language_test/cross_language_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests writing trajectories in one language and reading from another.""" 17 | 18 | from collections.abc import Sequence 19 | import os 20 | import shutil 21 | import subprocess 22 | 23 | from absl import logging 24 | from absl.testing import absltest 25 | from absl.testing import parameterized 26 | 27 | from rules_python.python.runfiles import runfiles 28 | 29 | 30 | def _execute_binary(rel_path: str, args: Sequence[str]) -> bytes: 31 | r = runfiles.Create() 32 | path = r.Rlocation(os.path.join('__main__', 'envlogger', rel_path)) 33 | cmd = [path] + args 34 | return subprocess.check_output(cmd, env=r.EnvVars()) 35 | 36 | 37 | class CrossLanguageTest(parameterized.TestCase): 38 | 39 | def test_py_writer_cc_reader(self): 40 | # Set up a trajectory directory. 
41 | trajectories_dir = os.path.join(absltest.TEST_TMPDIR.value, 'my_trajectory') 42 | logging.info('trajectories_dir: %r', trajectories_dir) 43 | os.makedirs(trajectories_dir) 44 | 45 | # Find Python writer and run it. 46 | py_writer_output = _execute_binary( 47 | 'backends/cross_language_test/py_writer', 48 | args=[f'--trajectories_dir={trajectories_dir}']) 49 | logging.info('py_writer_output: %r', py_writer_output) 50 | 51 | # Find C++ reader and run it. 52 | cc_reader_output = _execute_binary( 53 | 'backends/cross_language_test/cc_reader', 54 | args=[f'--trajectories_dir={trajectories_dir}']) 55 | logging.info('cc_reader_output: %r', cc_reader_output) 56 | 57 | # If everything went well, there should be no 58 | # `subprocess.CalledProcessError`. 59 | 60 | logging.info('Cleaning up trajectories_dir %r', trajectories_dir) 61 | shutil.rmtree(trajectories_dir) 62 | 63 | 64 | if __name__ == '__main__': 65 | absltest.main() 66 | -------------------------------------------------------------------------------- /envlogger/backends/cross_language_test/py_writer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """A simple python binary that creates a simple RL trajectory.""" 17 | 18 | from collections.abc import Sequence 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | import dm_env 24 | import envlogger 25 | import numpy as np 26 | 27 | _TRAJECTORIES_DIR = flags.DEFINE_string( 28 | 'trajectories_dir', None, 'Path to write trajectory.', required=True) 29 | 30 | 31 | def main(argv: Sequence[str]) -> None: 32 | if len(argv) > 1: 33 | raise app.UsageError('Too many command-line arguments.') 34 | 35 | logging.info('Starting Python-based writer...') 36 | logging.info('--trajectories_dir: %r', _TRAJECTORIES_DIR.value) 37 | 38 | writer = envlogger.RiegeliBackendWriter( 39 | data_directory=_TRAJECTORIES_DIR.value, metadata={'my_data': [1, 2, 3]}) 40 | writer.record_step( 41 | envlogger.StepData( 42 | timestep=dm_env.TimeStep( 43 | observation=np.array([0.0], dtype=np.float32), 44 | reward=0.0, 45 | discount=0.99, 46 | step_type=dm_env.StepType.FIRST), 47 | action=np.int32(100)), 48 | is_new_episode=True) 49 | for i in range(1, 100): 50 | writer.record_step( 51 | envlogger.StepData( 52 | timestep=dm_env.TimeStep( 53 | observation=np.array([float(i)], dtype=np.float32), 54 | reward=i / 100.0, 55 | discount=0.99, 56 | step_type=dm_env.StepType.MID), 57 | action=np.int32(100 - i)), 58 | is_new_episode=False) 59 | writer.close() 60 | 61 | 62 | if __name__ == '__main__': 63 | app.run(main) 64 | -------------------------------------------------------------------------------- /envlogger/backends/in_memory_backend.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Environment logger backend that stores all data in RAM. 17 | """ 18 | 19 | import copy 20 | from typing import Any 21 | 22 | from envlogger import step_data 23 | from envlogger.backends import backend_reader 24 | from envlogger.backends import backend_writer 25 | from envlogger.backends.python import episode_info 26 | 27 | 28 | 29 | class InMemoryBackendWriter(backend_writer.BackendWriter): 30 | """Backend that stores trajectory data in memory.""" 31 | 32 | def __init__(self, **base_kwargs): 33 | super().__init__(**base_kwargs) 34 | self.steps = [] 35 | self.episode_metadata = {} 36 | self.episode_start_indices = [] 37 | 38 | def _record_step(self, data: step_data.StepData, 39 | is_new_episode: bool) -> None: 40 | if is_new_episode: 41 | self.episode_start_indices.append(len(self.steps)) 42 | self.steps.append(data) 43 | 44 | def set_episode_metadata(self, data: Any) -> None: 45 | current_episode = len(self.episode_start_indices) 46 | if current_episode > 0: 47 | self.episode_metadata[current_episode] = data 48 | 49 | def close(self) -> None: 50 | pass 51 | 52 | 53 | class InMemoryBackendReader(backend_reader.BackendReader): 54 | """Reader that reads data from an InMemoryBackend.""" 55 | 56 | def __init__(self, in_memory_backend_writer: InMemoryBackendWriter): 57 | self._backend = in_memory_backend_writer 58 | super().__init__() 59 | 60 | def _copy(self) -> 'InMemoryBackendReader': 61 | return copy.deepcopy(self) 62 | 63 | def close(self) -> None: 64 | pass 65 | 66 | def _get_nth_step(self, i: int) -> 
step_data.StepData: 67 | return self._backend.steps[i] 68 | 69 | def _get_nth_episode_info(self, 70 | i: int, 71 | include_metadata: bool = False 72 | ) -> episode_info.EpisodeInfo: 73 | if i == len(self._backend.episode_start_indices) - 1: # Last episode. 74 | length = len(self._backend.steps) - self._backend.episode_start_indices[i] 75 | else: 76 | length = (self._backend.episode_start_indices[i + 1] - 77 | self._backend.episode_start_indices[i]) 78 | episode_metadata = self._backend.episode_metadata.get(i, None) 79 | return episode_info.EpisodeInfo( 80 | start=self._backend.episode_start_indices[i], 81 | num_steps=length, 82 | metadata=episode_metadata) 83 | 84 | def _get_num_steps(self) -> int: 85 | return len(self._backend.steps) 86 | 87 | def _get_num_episodes(self) -> int: 88 | return len(self._backend.episode_start_indices) 89 | 90 | def metadata(self): 91 | return self._backend.metadata() 92 | -------------------------------------------------------------------------------- /envlogger/backends/python/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
# `py_test` is loaded explicitly (as in cross_language_test/BUILD.bazel) so the
# package does not rely on the native Python rules.
load("@rules_python//python:defs.bzl", "py_test")

package(default_visibility = ["//envlogger:__subpackages__"])

pybind_extension(
    name = "riegeli_dataset_reader",
    srcs = ["riegeli_dataset_reader.cc"],
    deps = [
        "//envlogger/backends/cc:riegeli_dataset_reader",
        "//envlogger/proto:storage_cc_proto",
        "@com_google_riegeli//riegeli/bytes:string_writer",
        "@com_google_riegeli//riegeli/endian:endian_writing",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

pybind_extension(
    name = "riegeli_dataset_writer",
    srcs = ["riegeli_dataset_writer.cc"],
    deps = [
        "//envlogger/backends/cc:riegeli_dataset_writer",
        "//envlogger/proto:storage_cc_proto",
        "@com_google_riegeli//riegeli/bytes:string_writer",
        "@com_google_riegeli//riegeli/endian:endian_writing",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

pybind_extension(
    name = "episode_info",
    srcs = ["episode_info.cc"],
    deps = [
        "//envlogger/backends/cc:episode_info",
        "//envlogger/proto:storage_cc_proto",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

py_test(
    name = "episode_info_test",
    srcs = ["episode_info_test.py"],
    # The test imports the compiled extension module directly.
    data = [":episode_info.so"],
    deps = [
        "//envlogger/proto:storage_py_pb2",
    ],
)

--------------------------------------------------------------------------------
/envlogger/backends/python/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /envlogger/backends/python/episode_info.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "envlogger/backends/cc/episode_info.h" 16 | 17 | #include 18 | #include 19 | 20 | #include "envlogger/proto/storage.pb.h" 21 | #include "pybind11//pybind11.h" 22 | #include "pybind11//stl.h" 23 | #include "pybind11_protobuf/proto_casters.h" 24 | 25 | PYBIND11_MODULE(episode_info, m) { 26 | pybind11::google::ImportProtoModule(); 27 | 28 | m.doc() = "EpisodeInfo bindings."; 29 | 30 | pybind11::class_(m, "EpisodeInfo") 31 | .def(pybind11::init>(), 32 | pybind11::arg("start") = 0, pybind11::arg("num_steps") = 0, 33 | pybind11::arg("metadata") = std::nullopt) 34 | .def_readwrite("start", &envlogger::EpisodeInfo::start) 35 | .def_readwrite("num_steps", &envlogger::EpisodeInfo::num_steps) 36 | .def_readwrite("metadata", &envlogger::EpisodeInfo::metadata); 37 | } 38 | -------------------------------------------------------------------------------- /envlogger/backends/python/episode_info_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for episode_info.cc.""" 17 | 18 | import random 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from envlogger.backends.python import episode_info 23 | from envlogger.proto import storage_pb2 24 | 25 | 26 | class EpisodeInfoTest(parameterized.TestCase): 27 | 28 | def test_empty_episode_info(self): 29 | episode = episode_info.EpisodeInfo() 30 | self.assertEqual(episode.start, 0) 31 | self.assertEqual(episode.num_steps, 0) 32 | self.assertIsNone(episode.metadata) 33 | 34 | def test_episode_info_init_with_random_kwargs(self): 35 | random_starts = [random.randint(-1, 10000) for _ in range(100)] 36 | random_num_steps = [random.randint(-1, 10000) for _ in range(100)] 37 | random_metadata = [] 38 | 39 | dimension = storage_pb2.Datum.Shape.Dim() 40 | dimension.size = -438 41 | for _ in range(100): 42 | metadata = storage_pb2.Data() 43 | metadata.datum.shape.dim.append(dimension) 44 | metadata.datum.values.int32_values.append(random.randint(-1, 10000)) 45 | random_metadata.append(metadata) 46 | 47 | for start, num_steps, metadata in zip(random_starts, random_num_steps, 48 | random_metadata): 49 | episode = episode_info.EpisodeInfo( 50 | start=start, num_steps=num_steps, metadata=metadata) 51 | self.assertEqual(episode.start, start) 52 | self.assertEqual(episode.num_steps, num_steps) 53 | self.assertSequenceEqual(episode.metadata.datum.values.int32_values, 54 | metadata.datum.values.int32_values) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /envlogger/backends/python/riegeli_dataset_writer.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "envlogger/backends/cc/riegeli_dataset_writer.h" 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "envlogger/proto/storage.pb.h" 22 | #include "pybind11//pybind11.h" 23 | #include "pybind11//stl.h" 24 | #include "pybind11_protobuf/proto_casters.h" 25 | #include "riegeli/bytes/string_writer.h" 26 | #include "riegeli/endian/endian_writing.h" 27 | 28 | namespace { 29 | 30 | // This traverses the proto to convert repeated 'float_values' into big-endian 31 | // byte arrays 'float_values_buffer'. In python this allows using 32 | // np.frombuffer(float_values_buffer) which is overall 2.3x more efficient than 33 | // building the np array from the repeated field (this includes the cost of 34 | // doing the conversion in C++). 35 | // In the future, we could do the same for other data types stored in repeated 36 | // fields. 
37 | void OptimizeDataProto(envlogger::Data* data) { 38 | switch (data->value_case()) { 39 | case envlogger::Data::kDatum: { 40 | auto* datum = data->mutable_datum(); 41 | if (!datum->values().float_values().empty()) { 42 | riegeli::StringWriter writer( 43 | datum->mutable_values()->mutable_float_values_buffer()); 44 | writer.SetWriteSizeHint(datum->values().float_values_size() * 45 | sizeof(float)); 46 | riegeli::WriteBigEndianFloats(datum->values().float_values(), writer); 47 | writer.Close(); 48 | datum->mutable_values()->clear_float_values(); 49 | } 50 | break; 51 | } 52 | case envlogger::Data::kArray: 53 | for (auto& value : *data->mutable_array()->mutable_values()) { 54 | OptimizeDataProto(&value); 55 | } 56 | break; 57 | case envlogger::Data::kTuple: 58 | for (auto& value : *data->mutable_tuple()->mutable_values()) { 59 | OptimizeDataProto(&value); 60 | } 61 | break; 62 | case envlogger::Data::kDict: 63 | for (auto& value : *data->mutable_dict()->mutable_values()) { 64 | OptimizeDataProto(&value.second); 65 | } 66 | break; 67 | case envlogger::Data::VALUE_NOT_SET: 68 | break; 69 | } 70 | } 71 | } // namespace 72 | 73 | PYBIND11_MODULE(riegeli_dataset_writer, m) { 74 | pybind11::google::ImportProtoModule(); 75 | pybind11::module::import("envlogger.backends.python.episode_info"); 76 | 77 | m.doc() = "RiegeliDatasetWriter bindings."; 78 | 79 | pybind11::class_(m, "RiegeliDatasetWriter") 80 | .def(pybind11::init<>()) 81 | // Initializes the writer with the given arguments. 82 | // If successful, `void` is returned with no side effects. Otherwise a 83 | // `RuntimeError` is raised with an accompanying message. 84 | // Note: `absl::Status` isn't used because there are incompatibilities 85 | // between slightly different versions of `pybind11_abseil` when used with 86 | // different projects. Please see 87 | // https://github.com/deepmind/envlogger/issues/3 for more details. 
88 | .def( 89 | "init", 90 | [](envlogger::RiegeliDatasetWriter* self, std::string data_dir, 91 | const envlogger::Data& metadata = envlogger::Data(), 92 | int64_t max_episodes_per_shard = 0, 93 | std::string writer_options = 94 | "transpose,brotli:6,chunk_size:1M") -> void { 95 | const absl::Status status = self->Init( 96 | data_dir, metadata, max_episodes_per_shard, writer_options); 97 | if (!status.ok()) throw std::runtime_error(status.ToString()); 98 | }, 99 | pybind11::arg("data_dir"), 100 | pybind11::arg("metadata") = envlogger::Data(), 101 | pybind11::arg("max_episodes_per_shard") = 0, 102 | pybind11::arg("writer_options") = "transpose,brotli:6,chunk_size:1M") 103 | .def("add_step", &envlogger::RiegeliDatasetWriter::AddStep, 104 | pybind11::arg("data"), pybind11::arg("is_new_episode") = false) 105 | .def("set_episode_metadata", 106 | &envlogger::RiegeliDatasetWriter::SetEpisodeMetadata, 107 | pybind11::arg("data")) 108 | .def("flush", &envlogger::RiegeliDatasetWriter::Flush) 109 | .def("close", &envlogger::RiegeliDatasetWriter::Close) 110 | // Accessors. 111 | .def("data_dir", &envlogger::RiegeliDatasetWriter::DataDir) 112 | .def("max_episodes_per_shard", 113 | &envlogger::RiegeliDatasetWriter::MaxEpisodesPerShard) 114 | .def("writer_options", &envlogger::RiegeliDatasetWriter::WriterOptions) 115 | .def("episode_counter", &envlogger::RiegeliDatasetWriter::EpisodeCounter) 116 | // Pickling support. 117 | .def(pybind11::pickle( 118 | [](const envlogger::RiegeliDatasetWriter& self) { // __getstate__(). 119 | pybind11::dict output; 120 | output["data_dir"] = self.DataDir(); 121 | output["max_episodes_per_shard"] = self.MaxEpisodesPerShard(); 122 | output["writer_options"] = self.WriterOptions(); 123 | output["episode_counter_"] = self.EpisodeCounter(); 124 | return output; 125 | }, 126 | [](pybind11::dict d) { // __setstate__(). 
127 | const std::string data_dir = d["data_dir"].cast(); 128 | const int64_t max_episodes_per_shard = 129 | d["max_episodes_per_shard"].cast(); 130 | const std::string writer_options = 131 | d["writer_options"].cast(); 132 | 133 | auto writer = std::make_unique(); 134 | const absl::Status status = writer->Init( 135 | /*data_dir=*/data_dir, /*metadata=*/envlogger::Data(), 136 | /*max_episodes_per_shard=*/max_episodes_per_shard, 137 | /*writer_options=*/writer_options); 138 | if (!status.ok()) { 139 | throw std::runtime_error( 140 | "Failed to initialize RiegeliDatasetWriter: " + 141 | status.ToString()); 142 | } 143 | return writer; 144 | })); 145 | } 146 | -------------------------------------------------------------------------------- /envlogger/backends/riegeli_backend_reader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """For reading trajectory data from riegeli files.""" 17 | 18 | import copy 19 | from typing import Any 20 | 21 | from absl import logging 22 | import dm_env 23 | from envlogger import step_data 24 | from envlogger.backends import backend_reader 25 | from envlogger.backends.python import episode_info 26 | from envlogger.backends.python import riegeli_dataset_reader 27 | from envlogger.converters import codec 28 | from envlogger.proto import storage_pb2 29 | 30 | 31 | 32 | 33 | 34 | class RiegeliBackendReader(backend_reader.BackendReader): 35 | """A class that reads logs produced by an EnvironmentLoggerWrapper instance. 36 | 37 | Attributes: 38 | episodes: Traverse the data episode-wise in list-like fashion. 39 | steps: Traverse the data stepwise in list-like fashion. 40 | """ 41 | 42 | def __init__(self, data_directory: str): 43 | self._reader = riegeli_dataset_reader.RiegeliDatasetReader() 44 | try: 45 | self._reader.init(data_directory) 46 | except RuntimeError as e: 47 | error_message = str(e) 48 | if error_message.startswith('NOT_FOUND: Empty steps in '): 49 | # This case happens frequently when clients abruptly kill the 50 | # EnvironmentLogger without calling its .close() method, which then 51 | # causes the last shard to be truncated. This can be because the client 52 | # exited successfully and "forgot" to call .close(), which is a bug, but 53 | # also because of a preempted work unit, which is expected to happen 54 | # under distributed settings. 55 | # We can't do much to fix the bad usages, but we can be a bit more 56 | # permissive and try to read the successful shards. 57 | logging.exception("""Ignoring error due to empty step offset file. 
58 | ********************************* 59 | **** You likely forgot to *** 60 | **** call close() on your env *** 61 | **** *** 62 | *********************************""") 63 | else: 64 | raise 65 | 66 | self._metadata = codec.decode(self._reader.metadata()) or {} 67 | super().__init__() 68 | 69 | def _copy(self) -> 'RiegeliBackendReader': 70 | c = copy.copy(self) 71 | 72 | c._metadata = copy.deepcopy(self._metadata) 73 | c._reader = self._reader.clone() 74 | 75 | return c 76 | 77 | def close(self): 78 | if self._reader is not None: 79 | self._reader.close() 80 | self._reader = None 81 | 82 | def _decode_step_data(self, data: tuple[Any, Any, Any]) -> step_data.StepData: 83 | """Recovers dm_env.TimeStep from logged data (either dict or tuple).""" 84 | # Recover the TimeStep from the first tuple element. 85 | timestep = dm_env.TimeStep( 86 | dm_env.StepType(data[0][0]), data[0][1], data[0][2], data[0][3]) 87 | return step_data.StepData(timestep, data[1], data[2]) 88 | 89 | def _get_num_steps(self): 90 | return self._reader.num_steps 91 | 92 | def _get_num_episodes(self): 93 | return self._reader.num_episodes 94 | 95 | def _get_nth_step(self, i: int) -> step_data.StepData: 96 | """Returns the timestep given by offset `i` (0-based).""" 97 | serialized_data = self._reader.serialized_step(i) 98 | data = storage_pb2.Data.FromString(serialized_data) 99 | return self._decode_step_data(codec.decode(data)) 100 | 101 | def _get_nth_episode_info(self, 102 | i: int, 103 | include_metadata: bool = False 104 | ) -> episode_info.EpisodeInfo: 105 | """Returns the index of the start of nth episode, and its length.""" 106 | return self._reader.episode(i, include_metadata) 107 | 108 | def metadata(self): 109 | return self._metadata 110 | -------------------------------------------------------------------------------- /envlogger/backends/riegeli_backend_writer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 
DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """For writing trajectory data to riegeli files.""" 17 | 18 | from typing import Any, Optional 19 | 20 | from absl import logging 21 | from envlogger import step_data 22 | from envlogger.backends import backend_writer 23 | from envlogger.backends import schedulers 24 | from envlogger.backends.python import riegeli_dataset_writer 25 | from envlogger.converters import codec 26 | 27 | 28 | 29 | class RiegeliBackendWriter(backend_writer.BackendWriter): 30 | """Backend that writes trajectory data to riegeli files.""" 31 | 32 | def __init__( 33 | self, 34 | data_directory: str, 35 | max_episodes_per_file: int = 10000, 36 | writer_options: str = 'transpose,brotli:6,chunk_size:1M', 37 | flush_scheduler: Optional[schedulers.Scheduler] = None, 38 | **base_kwargs, 39 | ): 40 | """Constructor. 41 | 42 | Calling `close()` will flush the trajectories and the index to disk and will 43 | ensure that they can be read later on. If it isn't called, there is a large 44 | risk of losing data. This is particularly common in some RL frameworks that 45 | do not clean up their environments. If the environment runs for a very long 46 | time, this can happen only to the last shard, but if the instance is 47 | short-lived, then a large portion of the trajectories can disappear. 48 | 49 | Args: 50 | data_directory: Destination for the episode data. 
IMPORTANT: 51 | `data_directory` MUST exist _before_ calling `__init__()`. 52 | max_episodes_per_file: maximum number of episodes stored in one file. 53 | writer_options: Comma-separated list of options that are passed to Riegeli 54 | RecordWriter as is. 55 | flush_scheduler: This controls when data is flushed to permanent storage. 56 | If `None`, it defaults to a step-wise Bernoulli scheduler with 1/5000 57 | chances of flushing. 58 | **base_kwargs: arguments for the base class. 59 | """ 60 | super().__init__(**base_kwargs) 61 | self._data_directory = data_directory 62 | if flush_scheduler is None: 63 | self._flush_scheduler = schedulers.BernoulliStepScheduler(1.0 / 5000) 64 | else: 65 | self._flush_scheduler = flush_scheduler 66 | self._data_writer = riegeli_dataset_writer.RiegeliDatasetWriter() 67 | logging.info('self._data_directory: %r', self._data_directory) 68 | 69 | metadata = self._metadata or {} 70 | 71 | try: 72 | self._data_writer.init( 73 | data_dir=data_directory, 74 | metadata=codec.encode(metadata), 75 | max_episodes_per_shard=max_episodes_per_file, 76 | writer_options=writer_options) 77 | except RuntimeError as e: 78 | logging.exception('exception: %r', e) 79 | 80 | def _record_step(self, data: step_data.StepData, 81 | is_new_episode: bool) -> None: 82 | encoded_data = codec.encode(data) 83 | if not self._data_writer.add_step(encoded_data, is_new_episode): 84 | raise RuntimeError( 85 | 'Failed to write `data`. 
Please see logs for more details.') 86 | 87 | if self._flush_scheduler is not None and not self._flush_scheduler(data): 88 | return 89 | self._data_writer.flush() 90 | 91 | def set_episode_metadata(self, data: Any) -> None: 92 | encoded_data = codec.encode(data) 93 | self._data_writer.set_episode_metadata(encoded_data) 94 | 95 | def close(self) -> None: 96 | logging.info('Deleting the backend with data_dir: %r', self._data_directory) 97 | self._data_writer.close() 98 | logging.info('Done deleting the backend with data_dir: %r', 99 | self._data_directory) 100 | -------------------------------------------------------------------------------- /envlogger/backends/rlds_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils to convert Envlogger data into RLDS.""" 17 | 18 | from typing import Any, Optional 19 | 20 | from absl import logging 21 | from envlogger import step_data 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow_datasets as tfds 25 | 26 | Step = dict[str, Any] 27 | 28 | 29 | def to_rlds_step(prev_step: step_data.StepData, 30 | step: Optional[step_data.StepData]) -> Step: 31 | """Builds an RLDS step from two Envlogger steps. 32 | 33 | Steps follow the RLDS convention from https://github.com/google-research/rlds. 
34 | 35 | Args: 36 | prev_step: previous step. 37 | step: current step. If None, it builds the last step (where the observation 38 | is the last one, and the action, reward and discount are undefined). 39 | 40 | Returns: 41 | RLDS Step. 42 | 43 | """ 44 | metadata = {} 45 | if isinstance(prev_step.custom_data, dict): 46 | metadata = prev_step.custom_data 47 | return { 48 | 'action': 49 | step.action if step else tf.nest.map_structure( 50 | np.zeros_like, prev_step.action), 51 | 'discount': 52 | step.timestep.discount if step else tf.nest.map_structure( 53 | np.zeros_like, prev_step.timestep.discount), 54 | 'is_first': 55 | prev_step.timestep.first(), 56 | 'is_last': 57 | prev_step.timestep.last(), 58 | 'is_terminal': (prev_step.timestep.last() and 59 | prev_step.timestep.discount == 0.0), 60 | 'observation': 61 | prev_step.timestep.observation, 62 | 'reward': 63 | step.timestep.reward if step else tf.nest.map_structure( 64 | np.zeros_like, prev_step.timestep.reward), 65 | **metadata, 66 | } 67 | 68 | 69 | def _find_extra_shard(split_info: tfds.core.SplitInfo) -> Optional[Any]: 70 | """Returns the filename of the extra shard, or None if all shards are in the metadata.""" 71 | if split_info.filename_template is None: 72 | # Filename template is not initialized. 73 | return None 74 | filepath = split_info.filename_template.sharded_filepath( 75 | shard_index=split_info.num_shards, num_shards=split_info.num_shards + 1) 76 | if tf.io.gfile.exists(filepath): 77 | # There is one extra shard for which we don't have metadata. 78 | return filepath 79 | return None 80 | 81 | 82 | def maybe_recover_last_shard(builder: tfds.core.DatasetBuilder): 83 | """Goes through the splits and recovers the incomplete shards. 84 | 85 | It checks if the last shard is missing. If that is the case, it rewrites the 86 | metadata. This requires to read the full shard so it may take some time. 
87 | 88 | We assume that only the last shard can be unaccounted for in the 89 | metadata because the logger generates shards sequentially and it updates the 90 | metadata once a shard is done and before starting the new shard. 91 | 92 | Args: 93 | builder: TFDS builder of the dataset that may have incomplete shards. 94 | 95 | Returns: 96 | A builder with the new split information. 97 | 98 | """ 99 | split_infos = builder.info.splits 100 | splits_to_update = 0 101 | for _, split_info in split_infos.items(): 102 | extra_shard = _find_extra_shard(split_info) 103 | if extra_shard is None: 104 | continue 105 | logging.info('Recovering data for shard %s.', extra_shard) 106 | splits_to_update += 1 107 | ds = tf.data.TFRecordDataset(extra_shard) 108 | num_examples = 0 109 | num_bytes = 0 110 | for ex in ds: 111 | num_examples += 1 112 | num_bytes += len(ex.numpy()) 113 | 114 | new_split_info = split_info.replace( 115 | shard_lengths=split_info.shard_lengths + [num_examples], 116 | num_bytes=split_info.num_bytes + num_bytes) 117 | old_splits = [ 118 | v for k, v in builder.info.splits.items() if k != new_split_info.name 119 | ] 120 | builder.info.set_splits(tfds.core.SplitDict(old_splits + [new_split_info])) 121 | if splits_to_update > 0: 122 | builder.info.write_to_directory(builder.data_dir) 123 | return builder 124 | -------------------------------------------------------------------------------- /envlogger/backends/schedulers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common logging scheduling strategies."""

from collections.abc import Callable
from typing import Optional, Union

from envlogger import step_data
import numpy as np


# A Scheduler returns True when something should be activated and False
# otherwise.
Scheduler = Callable[[step_data.StepData], bool]


class NStepScheduler:
  """Returns `True` every N times it is called."""

  def __init__(self, step_interval: int):
    """Initializes the scheduler.

    Args:
      step_interval: N, the number of calls between two `True` results.

    Raises:
      ValueError: if `step_interval` is not positive.
    """
    if step_interval <= 0:
      raise ValueError(f'step_interval must be positive, got {step_interval}')

    self._step_interval = step_interval
    self._step_counter = 0

  def __call__(self, unused_data: step_data.StepData):
    """Returns `True` on calls 0, N, 2N, ... (the first call included)."""

    should_log = self._step_counter % self._step_interval == 0
    self._step_counter += 1
    return should_log


class BernoulliStepScheduler:
  """Returns `True` with a given probability."""

  def __init__(self, keep_probability: float, seed: Optional[int] = None):
    """Initializes the scheduler.

    Args:
      keep_probability: probability in [0, 1] of returning `True` on a call.
      seed: optional seed for the internal NumPy random generator.

    Raises:
      ValueError: if `keep_probability` is outside [0, 1] or is NaN.
    """
    # The chained form also rejects NaN, for which both `p < 0` and `p > 1`
    # are False; a NaN probability would otherwise make every call False.
    if not 0.0 <= keep_probability <= 1.0:
      raise ValueError(
          f'keep_probability must be in [0,1], got: {keep_probability}')

    self._keep_probability = keep_probability
    self._rng = np.random.default_rng(seed)

  def __call__(self, unused_data: step_data.StepData):
    """Returns `True` with probability `self._keep_probability`."""

    return self._rng.random() < self._keep_probability

| 65 | class NEpisodeScheduler: 66 | """Returns `True` every N episodes.""" 67 | 68 | def __init__(self, episode_interval: int): 69 | if episode_interval <= 0: 70 | raise ValueError( 71 | f'episode_interval must be positive, got {episode_interval}') 72 | 73 | self._episode_interval = episode_interval 74 | self._episode_counter = 0 75 | 76 | def __call__(self, data: step_data.StepData): 77 | """Returns `True` every N episodes.""" 78 | 79 | should_log = self._episode_counter % self._episode_interval == 0 80 | if data.timestep.last(): 81 | self._episode_counter += 1 82 | return should_log 83 | 84 | 85 | class BernoulliEpisodeScheduler: 86 | """Returns `True` with a given probability at every episode.""" 87 | 88 | def __init__(self, keep_probability: float, seed: Optional[int] = None): 89 | if keep_probability < 0.0 or keep_probability > 1.0: 90 | raise ValueError( 91 | f'keep_probability must be in [0,1], got: {keep_probability}') 92 | 93 | self._keep_probability = keep_probability 94 | self._rng = np.random.default_rng(seed) 95 | self._current_p = self._rng.random() 96 | 97 | def __call__(self, data: step_data.StepData): 98 | """Returns `True` with probability `self._keep_probability`.""" 99 | 100 | if data.timestep.last(): 101 | self._current_p = self._rng.random() 102 | return self._current_p < self._keep_probability 103 | 104 | 105 | class ListStepScheduler: 106 | """Returns `True` for steps in `desired_steps`. 107 | 108 | Please see unit tests for examples of using this scheduler. In particular, 109 | you can use Numpy's functions such as logspace() to generate non-linear steps. 
110 | """ 111 | 112 | def __init__(self, desired_steps: Union[list[int], np.ndarray]): 113 | if (isinstance(desired_steps, np.ndarray) and 114 | not (desired_steps.dtype == np.int32 or 115 | desired_steps.dtype == np.int64)): 116 | raise TypeError( 117 | f'desired_steps.dtype must be np.in32 or np.int64: {desired_steps} ' 118 | f'(dtype: {desired_steps.dtype})') 119 | if len(desired_steps) <= 0: 120 | raise ValueError(f'desired_steps cannot be empty: {desired_steps}') 121 | 122 | self._desired_steps = set(desired_steps) 123 | self._step_counter = 0 124 | 125 | def __call__(self, data: step_data.StepData): 126 | """Returns `True` every N episodes.""" 127 | 128 | should_log = self._step_counter in self._desired_steps 129 | self._step_counter += 1 130 | return should_log 131 | 132 | 133 | class ListEpisodeScheduler: 134 | """Returns `True` for episodes in `desired_episodes`. 135 | 136 | Please see unit tests for examples of using this scheduler. In particular, 137 | you can use Numpy's functions such as logspace() to generate non-linear steps. 
138 | """ 139 | 140 | def __init__(self, desired_episodes: Union[list[int], np.ndarray]): 141 | if (isinstance(desired_episodes, np.ndarray) and 142 | not (desired_episodes.dtype == np.int32 or 143 | desired_episodes.dtype == np.int64)): 144 | raise TypeError('desired_episodes.dtype must be np.in32 or np.int64: ' 145 | f'{desired_episodes} (dtype: {desired_episodes.dtype})') 146 | if len(desired_episodes) <= 0: 147 | raise ValueError(f'desired_episodes cannot be empty: {desired_episodes}') 148 | 149 | self._desired_episodes = set(desired_episodes) 150 | self._episode_counter = 0 151 | 152 | def __call__(self, data: step_data.StepData): 153 | """Returns `True` every N episodes.""" 154 | 155 | should_log = self._episode_counter in self._desired_episodes 156 | if data.timestep.last(): 157 | self._episode_counter += 1 158 | return should_log 159 | -------------------------------------------------------------------------------- /envlogger/backends/tfds_backend_testlib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utils to test the backends.""" 17 | import time 18 | from typing import Any, Optional 19 | 20 | from absl import logging 21 | from envlogger import step_data 22 | from envlogger.backends import backend_writer 23 | from envlogger.backends import tfds_backend_writer 24 | from envlogger.testing import catch_env 25 | import numpy as np 26 | import tensorflow as tf 27 | import tensorflow_datasets as tfds 28 | 29 | 30 | def generate_episode_data( 31 | backend: backend_writer.BackendWriter, 32 | num_episodes: int = 2, 33 | ) -> list[list[step_data.StepData]]: 34 | """Runs a Catch environment for `num_episodes` and logs them. 35 | 36 | Args: 37 | backend: environment logger writer. 38 | num_episodes: number of episodes to generate. 39 | 40 | Returns: 41 | List of generated episodes. 42 | 43 | """ 44 | env = catch_env.Catch() 45 | 46 | logging.info('Training a random agent for %r episodes...', num_episodes) 47 | episodes_data = [] 48 | for index in range(num_episodes): 49 | episode = [] 50 | timestep = env.reset() 51 | data = step_data.StepData(timestep, None, {'timestamp': int(time.time())}) 52 | episode.append(data) 53 | backend.record_step(data, is_new_episode=True) 54 | 55 | while not timestep.last(): 56 | action = np.random.randint(low=0, high=3) 57 | timestep = env.step(action) 58 | data = step_data.StepData(timestep, action, 59 | {'timestamp': int(time.time())}) 60 | episode.append(data) 61 | backend.record_step(data, is_new_episode=False) 62 | backend.set_episode_metadata({'episode_id': index}) 63 | episodes_data.append(episode) 64 | 65 | logging.info('Done training a random agent for %r episodes.', num_episodes) 66 | env.close() 67 | backend.close() 68 | return episodes_data 69 | 70 | 71 | def catch_env_tfds_config( 72 | name: str = 'catch_example') -> tfds.rlds.rlds_base.DatasetConfig: 73 | """Creates a TFDS DatasetConfig for the Catch environment.""" 74 | return tfds.rlds.rlds_base.DatasetConfig( 75 | name=name, 76 | 
observation_info=tfds.features.Tensor( 77 | shape=(10, 5), dtype=tf.float32, 78 | encoding=tfds.features.Encoding.ZLIB), 79 | action_info=tf.int64, 80 | reward_info=tf.float64, 81 | discount_info=tf.float64, 82 | step_metadata_info={'timestamp': tf.int64}, 83 | episode_metadata_info={'episode_id': tf.int64}) 84 | 85 | 86 | def tfds_backend_catch_env( 87 | data_directory: str, 88 | max_episodes_per_file: int = 1, 89 | split_name: Optional[str] = None, 90 | ds_metadata: Optional[dict[Any, Any]] = None, 91 | store_ds_metadata: bool = True, 92 | ) -> tfds_backend_writer.TFDSBackendWriter: 93 | """Creates a TFDS Backend Writer for the Catch Environment. 94 | 95 | Args: 96 | data_directory: directory where the data will be created (it has to exist). 97 | max_episodes_per_file: maximum number of episodes per file. 98 | split_name: number of the TFDS split to create. 99 | ds_metadata: metadata of the dataset. 100 | store_ds_metadata: if the metadata should be stored. 101 | Returns: 102 | TFDS backend writer. 103 | """ 104 | return tfds_backend_writer.TFDSBackendWriter( 105 | data_directory=data_directory, 106 | split_name=split_name, 107 | ds_config=catch_env_tfds_config(), 108 | max_episodes_per_file=max_episodes_per_file, 109 | metadata=ds_metadata, 110 | store_ds_metadata=store_ds_metadata) 111 | -------------------------------------------------------------------------------- /envlogger/backends/tfds_backend_writer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TFDS backend for Envlogger.""" 17 | import dataclasses 18 | from typing import Any, Optional 19 | 20 | from absl import logging 21 | from envlogger import step_data 22 | from envlogger.backends import backend_writer 23 | from envlogger.backends import rlds_utils 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | DatasetConfig = tfds.rlds.rlds_base.DatasetConfig 28 | 29 | 30 | @dataclasses.dataclass 31 | class Episode(object): 32 | """Episode that is being constructed.""" 33 | prev_step: step_data.StepData 34 | steps: Optional[list[rlds_utils.Step]] = None 35 | metadata: Optional[dict[str, Any]] = None 36 | 37 | def add_step(self, step: step_data.StepData) -> None: 38 | rlds_step = rlds_utils.to_rlds_step(self.prev_step, step) 39 | if self.steps is None: 40 | self.steps = [] 41 | self.steps.append(rlds_step) 42 | self.prev_step = step 43 | 44 | def get_rlds_episode(self) -> dict[str, Any]: 45 | last_step = rlds_utils.to_rlds_step(self.prev_step, None) 46 | if self.steps is None: 47 | self.steps = [] 48 | if self.metadata is None: 49 | self.metadata = {} 50 | 51 | return {'steps': self.steps + [last_step], **self.metadata} 52 | 53 | 54 | class TFDSBackendWriter(backend_writer.BackendWriter): 55 | """Backend that writes trajectory data in TFDS format (and RLDS structure).""" 56 | 57 | 58 | def __init__(self, 59 | data_directory: str, 60 | ds_config: tfds.rlds.rlds_base.DatasetConfig, 61 | max_episodes_per_file: int = 1000, 62 | split_name: Optional[str] = None, 63 | version: str = '0.0.1', 64 | 
store_ds_metadata: bool = False, 65 | **base_kwargs): 66 | """Constructor. 67 | 68 | Args: 69 | data_directory: Directory to store the data 70 | ds_config: Dataset Configuration. 71 | max_episodes_per_file: Number of episodes to store per shard. 72 | split_name: Name to be used by the split. If None, 'train' will be used. 73 | version: version (major.minor.patch) of the dataset. 74 | store_ds_metadata: if False, it won't store the dataset level 75 | metadata. 76 | **base_kwargs: arguments for the base class. 77 | """ 78 | super().__init__(**base_kwargs) 79 | if not split_name: 80 | split_name = 'train' 81 | ds_identity = tfds.core.dataset_info.DatasetIdentity( 82 | name=ds_config.name, 83 | version=tfds.core.Version(version), 84 | data_dir=data_directory, 85 | module_name='') 86 | if store_ds_metadata: 87 | metadata = self._metadata 88 | else: 89 | metadata = None 90 | self._data_directory = data_directory 91 | self._ds_info = tfds.rlds.rlds_base.build_info(ds_config, ds_identity, 92 | metadata) 93 | self._ds_info.set_file_format('tfrecord') 94 | 95 | self._current_episode = None 96 | 97 | self._sequential_writer = tfds.core.SequentialWriter( 98 | self._ds_info, max_episodes_per_file) 99 | self._split_name = split_name 100 | self._sequential_writer.initialize_splits([split_name]) 101 | logging.info('self._data_directory: %r', self._data_directory) 102 | 103 | def _write_and_reset_episode(self): 104 | if self._current_episode is not None: 105 | self._sequential_writer.add_examples( 106 | {self._split_name: [self._current_episode.get_rlds_episode()]}) 107 | self._current_episode = None 108 | 109 | def _record_step(self, data: step_data.StepData, 110 | is_new_episode: bool) -> None: 111 | """Stores RLDS steps in TFDS format.""" 112 | 113 | if is_new_episode: 114 | self._write_and_reset_episode() 115 | 116 | if self._current_episode is None: 117 | self._current_episode = Episode(prev_step=data) 118 | else: 119 | self._current_episode.add_step(data) 120 | 121 | def 
set_episode_metadata(self, data: dict[str, Any]) -> None: 122 | self._current_episode.metadata = data 123 | 124 | def close(self) -> None: 125 | logging.info('Deleting the backend with data_dir: %r', self._data_directory) 126 | self._write_and_reset_episode() 127 | self._sequential_writer.close_all() 128 | logging.info('Done deleting the backend with data_dir: %r', 129 | self._data_directory) 130 | -------------------------------------------------------------------------------- /envlogger/converters/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Converters to and from environment logger proto format. 
16 | load("@rules_python//python:defs.bzl", "py_library", "py_test") 17 | 18 | package(default_visibility = ["//visibility:public"]) 19 | 20 | py_library( 21 | name = "codec", 22 | srcs = ["codec.py"], 23 | deps = [ 24 | "//envlogger/proto:storage_py_pb2", 25 | ], 26 | ) 27 | 28 | py_test( 29 | name = "codec_test", 30 | srcs = ["codec_test.py"], 31 | deps = [ 32 | ":codec", 33 | "//envlogger/proto:storage_py_pb2", 34 | ], 35 | ) 36 | 37 | py_library( 38 | name = "spec_codec", 39 | srcs = ["spec_codec.py"], 40 | ) 41 | 42 | py_test( 43 | name = "spec_codec_test", 44 | srcs = ["spec_codec_test.py"], 45 | deps = [ 46 | ":spec_codec", 47 | "//envlogger/proto:storage_py_pb2", 48 | ], 49 | ) 50 | 51 | cc_library( 52 | name = "xtensor_codec", 53 | srcs = ["xtensor_codec.cc"], 54 | hdrs = ["xtensor_codec.h"], 55 | deps = [ 56 | "//envlogger/proto:storage_cc_proto", 57 | "@com_github_google_glog//:glog", 58 | "@com_google_absl//absl/base", 59 | "@com_google_absl//absl/strings", 60 | "@com_google_absl//absl/strings:cord", 61 | "@com_google_absl//absl/strings:str_format", 62 | "@com_google_riegeli//riegeli/bytes:string_writer", 63 | "@com_google_riegeli//riegeli/endian:endian_reading", 64 | "@com_google_riegeli//riegeli/endian:endian_writing", 65 | "@gmp", 66 | "@xtensor", 67 | ], 68 | ) 69 | 70 | cc_test( 71 | name = "xtensor_codec_test", 72 | srcs = ["xtensor_codec_test.cc"], 73 | deps = [ 74 | ":make_visitor", 75 | ":xtensor_codec", 76 | "//envlogger/platform:parse_text_proto", 77 | "//envlogger/platform:proto_testutil", 78 | "//envlogger/proto:storage_cc_proto", 79 | "@com_google_googletest//:gtest", 80 | "@com_google_googletest//:gtest_main", 81 | "@gmp", 82 | ], 83 | ) 84 | 85 | cc_library( 86 | name = "make_visitor", 87 | hdrs = ["make_visitor.h"], 88 | ) 89 | 90 | cc_test( 91 | name = "make_visitor_test", 92 | srcs = ["make_visitor_test.cc"], 93 | deps = [ 94 | ":make_visitor", 95 | "@com_google_googletest//:gtest", 96 | "@com_google_googletest//:gtest_main", 97 | ], 
98 | ) 99 | -------------------------------------------------------------------------------- /envlogger/converters/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /envlogger/converters/make_visitor.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_ 17 | 18 | // Template magic for more easily creating overloaded visitors for use with 19 | // std::visit. 
20 | // 21 | // Example: 22 | // 23 | // const std::variant my_variant = PopulateVariant(); 24 | // const auto visitor = envlogger::MakeVisitor( 25 | // [](const Foo& foo) { DoFooStuff(foo); }, 26 | // [](const Bar& bar) { DoBarStuff(bar); } 27 | // ); 28 | // std::visit(visitor, my_variant); 29 | 30 | namespace envlogger { 31 | 32 | // This uses C++17 type expansion to inherit from each of the lambda types 33 | // passed in to the constructor and inherit all of their operator()s. 34 | template 35 | struct Visitor : Visitors... { 36 | explicit Visitor(const Visitors&... v) : Visitors(v)... {} 37 | using Visitors::operator()...; 38 | }; 39 | 40 | template 41 | Visitor MakeVisitor(Visitors... visitors) { 42 | return Visitor(visitors...); 43 | } 44 | 45 | } // namespace envlogger 46 | 47 | #endif // THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_ 48 | -------------------------------------------------------------------------------- /envlogger/converters/make_visitor_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "envlogger/converters/make_visitor.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "gmock/gmock.h" 23 | #include "gtest/gtest.h" 24 | 25 | namespace envlogger { 26 | namespace { 27 | 28 | using ::testing::Eq; 29 | using ::testing::IsFalse; 30 | using ::testing::IsTrue; 31 | 32 | TEST(MakeVisitorTest, SanityCheck) { 33 | bool executed_char_visitor = false; 34 | bool executed_double_visitor = false; 35 | 36 | const auto visitor = MakeVisitor( 37 | [&](char) { 38 | executed_char_visitor = true; 39 | executed_double_visitor = false; 40 | }, 41 | [&](double) { 42 | executed_char_visitor = false; 43 | executed_double_visitor = true; 44 | }, 45 | [&](int) { FAIL() << "This shouldn't be called."; }); 46 | 47 | std::visit(visitor, std::variant('x')); 48 | EXPECT_THAT(executed_char_visitor, IsTrue()); 49 | EXPECT_THAT(executed_double_visitor, IsFalse()); 50 | 51 | std::visit(visitor, std::variant(1.5)); 52 | EXPECT_THAT(executed_char_visitor, IsFalse()); 53 | EXPECT_THAT(executed_double_visitor, IsTrue()); 54 | } 55 | 56 | struct SizeOfVisitor { 57 | template 58 | std::size_t operator()(const T&) const { 59 | return sizeof(T); 60 | } 61 | }; 62 | 63 | struct Padded { 64 | char pad[64]; 65 | }; 66 | constexpr size_t kPaddedSize = sizeof(Padded); 67 | 68 | TEST(MakeVisitorTest, TemplatedVisitor) { 69 | using VariantT = std::variant; 70 | 71 | const auto visitor = 72 | MakeVisitor([](std::nullptr_t) -> std::size_t { return 0; }, // 73 | SizeOfVisitor{}, // 74 | [](char) -> std::size_t { return -1; }); 75 | EXPECT_THAT(std::visit(visitor, VariantT('x')), Eq(static_cast(-1))); 76 | EXPECT_THAT(std::visit(visitor, VariantT(1.5)), Eq(sizeof(double))); 77 | EXPECT_THAT(std::visit(visitor, VariantT(nullptr)), Eq(0)); 78 | EXPECT_THAT(std::visit(visitor, VariantT(Padded{})), Eq(kPaddedSize)); 79 | } 80 | 81 | } // namespace 82 | } // namespace envlogger 83 | 
-------------------------------------------------------------------------------- /envlogger/converters/spec_codec.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Encoder/decoder for dm_env.specs.Array (and subclasses). 17 | """ 18 | 19 | from typing import Any, Optional, Union 20 | 21 | import dm_env 22 | from dm_env import specs 23 | import numpy as np 24 | 25 | 26 | _ENVIRONMENT_SPEC_NAMES = [ 27 | 'observation_spec', 28 | 'action_spec', 29 | 'reward_spec', 30 | 'discount_spec', 31 | ] 32 | 33 | 34 | def encode_environment_specs( 35 | env: Optional[dm_env.Environment], 36 | ) -> dict[str, Any]: 37 | """Encodes all the specs from a given environment.""" 38 | if env: 39 | return { 40 | 'observation_spec': encode(env.observation_spec()), 41 | 'action_spec': encode(env.action_spec()), 42 | 'reward_spec': encode(env.reward_spec()), 43 | 'discount_spec': encode(env.discount_spec()), 44 | } 45 | return {} 46 | 47 | 48 | def decode_environment_specs( 49 | encoded_specs: dict[str, Any], 50 | ) -> dict[str, Optional[specs.Array]]: 51 | """Decodes all the specs of an environment.""" 52 | if encoded_specs: 53 | return {spec_name: decode(encoded_specs[spec_name]) # pytype: disable=bad-return-type # always-use-return-annotations 54 | for spec_name in 
_ENVIRONMENT_SPEC_NAMES} 55 | return {spec_name: None for spec_name in _ENVIRONMENT_SPEC_NAMES} 56 | 57 | 58 | def _array_spec_to_dict(array_spec: specs.Array) -> dict[str, Any]: 59 | """Encodes an Array spec as a dictionary.""" 60 | dict_spec = { 61 | 'shape': np.array(array_spec.shape, dtype=np.int64), 62 | 'dtype': str(array_spec.dtype), 63 | 'name': array_spec.name, 64 | } 65 | if isinstance(array_spec, specs.BoundedArray): 66 | dict_spec.update({ 67 | 'minimum': array_spec.minimum, 68 | 'maximum': array_spec.maximum, 69 | }) 70 | if isinstance(array_spec, specs.DiscreteArray): 71 | dict_spec.update({'num_values': array_spec.num_values}) 72 | return dict_spec 73 | 74 | 75 | def encode( 76 | spec: Union[specs.Array, list[Any], tuple[Any, ...], dict[str, Any]], 77 | ) -> Union[list[Any], tuple[Any, ...], dict[str, Any]]: 78 | """Encodes `spec` using plain Python objects. 79 | 80 | This function supports bare Array specs, lists of Array specs, Tuples of Array 81 | specs, Dicts of string to Array specs and any combination of these things such 82 | as Dict[str, Tuple[List[Array, Array]]]. 83 | 84 | Args: 85 | spec: The actual spec to encode. 86 | Returns: 87 | The same spec encoded in a way that can be serialized to disk. 88 | Raises: 89 | TypeError: When the argument is not among the supported types. 90 | """ 91 | if isinstance(spec, specs.Array): 92 | return _array_spec_to_dict(spec) 93 | if isinstance(spec, list): 94 | return [encode(x) for x in spec] 95 | if isinstance(spec, tuple): 96 | return tuple((encode(x) for x in spec)) 97 | if isinstance(spec, dict): 98 | return {k: encode(v) for k, v in spec.items()} 99 | raise TypeError( 100 | 'encode() should be called with an argument of type specs.Array (and ' 101 | f'subclasses), list, tuple or dict. 
Found {type(spec)}: {spec}.') 102 | 103 | 104 | def decode( 105 | spec: Union[list[Any], tuple[Any, ...], dict[str, Any]], 106 | ) -> Union[specs.Array, list[Any], tuple[Any, ...], dict[str, Any]]: 107 | """Parses `spec` into the supported dm_env spec formats.""" 108 | if isinstance(spec, dict): 109 | if 'shape' in spec and 'dtype' in spec: 110 | shape = spec['shape'] if spec['shape'] is not None else () 111 | if 'num_values' in spec: 112 | # DiscreteArray case. 113 | return specs.DiscreteArray( 114 | num_values=spec['num_values'], 115 | dtype=spec['dtype'], 116 | name=spec['name']) 117 | elif 'minimum' in spec and 'maximum' in spec: 118 | # BoundedArray case. 119 | return specs.BoundedArray( 120 | shape=shape, 121 | dtype=spec['dtype'], 122 | minimum=spec['minimum'], 123 | maximum=spec['maximum'], 124 | name=spec['name']) 125 | else: 126 | # Base Array spec case. 127 | return specs.Array(shape=shape, dtype=spec['dtype'], name=spec['name']) 128 | # Recursively decode array elements. 129 | return {k: decode(v) for k, v in spec.items()} 130 | elif isinstance(spec, list): 131 | return [decode(x) for x in spec] 132 | elif isinstance(spec, tuple): 133 | return tuple(decode(x) for x in spec) 134 | raise TypeError( 135 | 'decode() should be called with an argument of type list, tuple or dict.' 136 | f' Found: {type(spec)}: {spec}.') 137 | -------------------------------------------------------------------------------- /envlogger/environment_wrapper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base class for implementing environment wrappers..""" 17 | 18 | import pickle 19 | 20 | import dm_env 21 | 22 | 23 | class EnvironmentWrapper(dm_env.Environment): 24 | """An Environment which delegates calls to another environment. 25 | 26 | Subclasses should override one or more methods to modify the behavior of the 27 | backing environment as desired per the Decorator Pattern. 28 | 29 | This exposes the wrapped environment to subclasses with the `._environment` 30 | property and also defines `__getattr__` so that attributes are invisibly 31 | forwarded to the wrapped environment (and hence enabling duck-typing). 
32 | """ 33 | 34 | def __init__(self, environment: dm_env.Environment): 35 | self._environment = environment 36 | 37 | def __getattr__(self, name): 38 | return getattr(self._environment, name) 39 | 40 | def __getstate__(self): 41 | return pickle.dumps(self._environment) 42 | 43 | def __setstate__(self, state): 44 | self._environment = pickle.loads(state) 45 | 46 | def step(self, action) -> dm_env.TimeStep: 47 | return self._environment.step(action) 48 | 49 | def reset(self) -> dm_env.TimeStep: 50 | return self._environment.reset() 51 | 52 | def action_spec(self): 53 | return self._environment.action_spec() 54 | 55 | def discount_spec(self): 56 | return self._environment.discount_spec() 57 | 58 | def observation_spec(self): 59 | return self._environment.observation_spec() 60 | 61 | def reward_spec(self): 62 | return self._environment.reward_spec() 63 | 64 | def close(self): 65 | return self._environment.close() 66 | -------------------------------------------------------------------------------- /envlogger/examples/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Examples. 
load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

# Targets every example binary needs: the EnvLogger library and the toy Catch
# environment the examples run against.
_COMMON_EXAMPLE_DEPS = [
    "//envlogger",
    "//envlogger/testing:catch_env",
]

# Random agent on Catch, logged through EnvLogger without an explicit backend.
py_binary(
    name = "random_agent_catch",
    srcs = ["random_agent_catch.py"],
    deps = _COMMON_EXAMPLE_DEPS,
)

# Same agent, but trajectories are written through the TFDS backend writer.
py_binary(
    name = "tfds_random_agent_catch",
    srcs = ["tfds_random_agent_catch.py"],
    deps = _COMMON_EXAMPLE_DEPS + [
        "//envlogger/backends:tfds_backend_writer",
    ],
)

-------------------------------------------------------------------------------- /envlogger/examples/__init__.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


-------------------------------------------------------------------------------- /envlogger/examples/random_agent_catch.py: --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple binary to run catch for a while and record its trajectories. 17 | """ 18 | 19 | import time 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | import envlogger 25 | from envlogger.testing import catch_env 26 | import numpy as np 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to log.') 32 | flags.DEFINE_string('trajectories_dir', '/tmp/catch_data/', 33 | 'Path in a filesystem to record trajectories.') 34 | 35 | 36 | def main(unused_argv): 37 | logging.info('Creating Catch environment...') 38 | env = catch_env.Catch() 39 | logging.info('Done creating Catch environment.') 40 | 41 | def step_fn(unused_timestep, unused_action, unused_env): 42 | return {'timestamp': time.time()} 43 | 44 | logging.info('Wrapping environment with EnvironmentLogger...') 45 | with envlogger.EnvLogger( 46 | env, 47 | data_directory=FLAGS.trajectories_dir, 48 | max_episodes_per_file=1000, 49 | metadata={ 50 | 'agent_type': 'random', 51 | 'env_type': type(env).__name__, 52 | 'num_episodes': FLAGS.num_episodes, 53 | }, 54 | step_fn=step_fn) as env: 55 | logging.info('Done wrapping environment with EnvironmentLogger.') 56 | 57 | logging.info('Training a random agent for %r episodes...', 58 | FLAGS.num_episodes) 59 | for i in range(FLAGS.num_episodes): 60 | logging.info('episode %r', i) 61 | timestep = env.reset() 62 | while not timestep.last(): 63 | action = np.random.randint(low=0, high=3) 64 | timestep = env.step(action) 65 | 
logging.info('Done training a random agent for %r episodes.', 66 | FLAGS.num_episodes) 67 | 68 | 69 | if __name__ == '__main__': 70 | app.run(main) 71 | -------------------------------------------------------------------------------- /envlogger/examples/tfds_random_agent_catch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple binary to run catch for a while and record its trajectories. 
17 | """ 18 | 19 | import time 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | import envlogger 25 | from envlogger.backends import tfds_backend_writer 26 | 27 | from envlogger.testing import catch_env 28 | import numpy as np 29 | import tensorflow as tf 30 | import tensorflow_datasets as tfds 31 | 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to log.') 37 | flags.DEFINE_string('trajectories_dir', '/tmp/catch_data/', 38 | 'Path in a filesystem to record trajectories.') 39 | 40 | 41 | def main(unused_argv): 42 | logging.info('Creating Catch environment...') 43 | env = catch_env.Catch() 44 | logging.info('Done creating Catch environment.') 45 | 46 | def step_fn(unused_timestep, unused_action, unused_env): 47 | return {'timestamp': time.time()} 48 | 49 | dataset_config = tfds.rlds.rlds_base.DatasetConfig( 50 | name='catch_example', 51 | observation_info=tfds.features.Tensor( 52 | shape=(10, 5), dtype=tf.float32, 53 | encoding=tfds.features.Encoding.ZLIB), 54 | action_info=tf.int64, 55 | reward_info=tf.float64, 56 | discount_info=tf.float64, 57 | step_metadata_info={'timestamp': tf.float32}) 58 | 59 | logging.info('Wrapping environment with EnvironmentLogger...') 60 | with envlogger.EnvLogger( 61 | env, 62 | step_fn=step_fn, 63 | backend = tfds_backend_writer.TFDSBackendWriter( 64 | data_directory=FLAGS.trajectories_dir, 65 | split_name='train', 66 | max_episodes_per_file=500, 67 | ds_config=dataset_config), 68 | ) as env: 69 | logging.info('Done wrapping environment with EnvironmentLogger.') 70 | 71 | logging.info('Training a random agent for %r episodes...', 72 | FLAGS.num_episodes) 73 | for i in range(FLAGS.num_episodes): 74 | logging.info('episode %r', i) 75 | timestep = env.reset() 76 | while not timestep.last(): 77 | action = np.random.randint(low=0, high=3) 78 | timestep = env.step(action) 79 | logging.info('Done training a random agent for %r episodes.', 
80 | FLAGS.num_episodes) 81 | 82 | 83 | if __name__ == '__main__': 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /envlogger/platform/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Utilities 16 | 17 | package(default_visibility = ["//visibility:public"]) 18 | 19 | cc_library( 20 | name = "bundle", 21 | hdrs = ["bundle.h"], 22 | deps = ["//envlogger/platform/default:bundle"], 23 | ) 24 | 25 | cc_library( 26 | name = "filesystem", 27 | hdrs = ["filesystem.h"], 28 | deps = [ 29 | ":status_macros", 30 | "//envlogger/platform/default:filesystem", 31 | "@com_google_absl//absl/status", 32 | "@com_google_absl//absl/status:statusor", 33 | "@com_google_absl//absl/strings", 34 | ], 35 | ) 36 | 37 | cc_library( 38 | name = "parse_text_proto", 39 | hdrs = ["parse_text_proto.h"], 40 | deps = ["//envlogger/platform/default:parse_text_proto"], 41 | ) 42 | 43 | cc_library( 44 | name = "proto_testutil", 45 | testonly = 1, 46 | hdrs = ["proto_testutil.h"], 47 | deps = ["//envlogger/platform/default:proto_testutil"], 48 | ) 49 | 50 | cc_library( 51 | name = "riegeli_file_reader", 52 | hdrs = ["riegeli_file_reader.h"], 53 | deps = ["//envlogger/platform/default:riegeli_file_reader"], 54 | ) 55 | 56 | cc_library( 57 | name = 
"riegeli_file_writer", 58 | hdrs = ["riegeli_file_writer.h"], 59 | deps = ["//envlogger/platform/default:riegeli_file_writer"], 60 | ) 61 | 62 | cc_library( 63 | name = "status_macros", 64 | hdrs = ["status_macros.h"], 65 | deps = ["//envlogger/platform/default:status_macros"], 66 | ) 67 | 68 | cc_library( 69 | name = "test_macros", 70 | testonly = 1, 71 | hdrs = ["test_macros.h"], 72 | deps = [ 73 | ":status_macros", 74 | "@com_google_googletest//:gtest", 75 | ], 76 | ) 77 | -------------------------------------------------------------------------------- /envlogger/platform/bundle.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_ 17 | 18 | #include "envlogger/platform/default/bundle.h" 19 | 20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_ 21 | -------------------------------------------------------------------------------- /envlogger/platform/default/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default platform specific targets.

package(default_visibility = ["//visibility:public"])

cc_library(
    name = "bundle",
    srcs = ["bundle.cc"],
    hdrs = ["bundle.h"],
    deps = ["@com_github_google_glog//:glog"],
)

cc_library(
    name = "filesystem",
    srcs = ["filesystem.cc"],
    hdrs = ["filesystem.h"],
    # Older GCC/libstdc++ releases (before GCC 9) ship std::filesystem in a
    # separate library that must be linked explicitly.
    linkopts = ["-lstdc++fs"],
    deps = [
        "//envlogger/platform:status_macros",
        # filesystem.cc uses absl::Status/absl::OkStatus directly, so depend on
        # the status target itself rather than only transitively via statusor.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

# NOTE(review): parse_text_proto.h includes google/protobuf headers but no
# protobuf dependency is declared here; it relies on dependents providing one.
cc_library(
    name = "parse_text_proto",
    hdrs = ["parse_text_proto.h"],
    deps = ["@com_github_google_glog//:glog"],
)

# NOTE(review): proto_testutil.h also includes google/protobuf headers without
# a declared protobuf dependency.
cc_library(
    name = "proto_testutil",
    testonly = 1,
    hdrs = ["proto_testutil.h"],
    deps = [
        "@com_github_google_glog//:glog",
        "@com_google_googletest//:gtest",
    ],
)

cc_library(
    name = "riegeli_file_reader",
    hdrs = ["riegeli_file_reader.h"],
    deps = [
        ":filesystem",
        "@com_google_absl//absl/strings",
        "@com_google_riegeli//riegeli/bytes:fd_reader",
    ],
)

cc_library(
    name = "riegeli_file_writer",
    hdrs = ["riegeli_file_writer.h"],
    deps = [
        ":filesystem",
        "@com_google_absl//absl/strings",
        "@com_google_riegeli//riegeli/bytes:fd_writer",
    ],
)

cc_library(
    name = "source_location",
    hdrs = ["source_location.h"],
)

cc_library(
    name = "status_builder",
    srcs = ["status_builder.cc"],
    hdrs = ["status_builder.h"],
    deps = [
        ":source_location",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "status_macros",
    hdrs = ["status_macros.h"],
    deps = [
        ":status_builder",
        "@com_google_absl//absl/status",
    ],
)

-------------------------------------------------------------------------------- /envlogger/platform/default/bundle.cc: --------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | #include "envlogger/platform/default/bundle.h" 16 | 17 | #include 18 | #include // NOLINT(build/c++11) 19 | #include 20 | 21 | #include "glog/logging.h" 22 | 23 | namespace envlogger { 24 | namespace thread { 25 | 26 | Bundle::Bundle() : finished_(false) {} 27 | Bundle::~Bundle() { 28 | CHECK(finished_) << "JoinAll() should be called before releasing the bundle."; 29 | } 30 | 31 | void Bundle::Add(std::function&& function) { 32 | CHECK(!finished_) << "Add cannot be called after JoinAll is invoked."; 33 | futures_.push_back(std::async(std::launch::async, std::move(function))); 34 | } 35 | 36 | void Bundle::JoinAll() { 37 | CHECK(!finished_) << "JoinAll should be called only once."; 38 | finished_ = true; 39 | for (const auto& future : futures_) { 40 | future.wait(); 41 | } 42 | futures_.clear(); 43 | } 44 | 45 | } // namespace thread 46 | } // namespace envlogger 47 | -------------------------------------------------------------------------------- /envlogger/platform/default/bundle.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_ 17 | 18 | #include 19 | #include // NOLINT(build/c++11) 20 | #include 21 | 22 | namespace envlogger { 23 | namespace thread { 24 | 25 | // Bundles a set of parallel function calls. 26 | class Bundle { 27 | public: 28 | Bundle(); 29 | ~Bundle(); 30 | 31 | // Adds the function for asynchronous execution. It cannot be called if 32 | // JoinAll() has been invoked. 33 | void Add(std::function&& function); 34 | 35 | // Waits for the execution of the added asynchronous functions to terminate. 36 | // It must be called before releasing the Bundle object. 37 | void JoinAll(); 38 | 39 | private: 40 | bool finished_ = false; 41 | std::vector> futures_; 42 | }; 43 | 44 | } // namespace thread 45 | } // namespace envlogger 46 | 47 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_ 48 | -------------------------------------------------------------------------------- /envlogger/platform/default/filesystem.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | #include 17 | #include 18 | #include // NOLINT(build/c++11) 19 | #include 20 | 21 | #include "absl/status/status.h" 22 | #include "absl/status/statusor.h" 23 | #include "absl/strings/str_cat.h" 24 | #include "absl/strings/string_view.h" 25 | 26 | namespace envlogger { 27 | namespace file { 28 | 29 | std::string JoinPath(absl::string_view dirname, absl::string_view basename) { 30 | return std::filesystem::path(dirname) / basename; 31 | } 32 | 33 | absl::Status CreateDir(absl::string_view path) { 34 | if (!std::filesystem::create_directory(path)) { 35 | return absl::InternalError( 36 | absl::StrCat("Unable to create directory ", path)); 37 | } 38 | return absl::OkStatus(); 39 | } 40 | 41 | absl::StatusOr> GetSubdirectories( 42 | absl::string_view path, absl::string_view sentinel) { 43 | std::vector subdirs; 44 | for (const auto& entry : std::filesystem::directory_iterator(path)) { 45 | if (entry.is_directory()) { 46 | if (!sentinel.empty() && 47 | !std::filesystem::exists(entry.path() / sentinel)) { 48 | continue; 49 | } 50 | subdirs.push_back(entry.path()); 51 | } 52 | } 53 | return subdirs; 54 | } 55 | 56 | absl::Status RecursivelyDelete(absl::string_view path) { 57 | if (!std::filesystem::remove_all(path)) { 58 | return absl::InternalError( 59 | absl::StrCat("Unable to recursively delete directory ", path)); 60 | } 61 | return absl::OkStatus(); 62 | } 63 | 64 | absl::StatusOr GetSize(absl::string_view path) { 65 | std::error_code ec; 66 | const std::uintmax_t size = std::filesystem::file_size(path, ec); 67 | if (ec) { 68 | return absl::NotFoundError(absl::StrCat("Could not find file ", path)); 69 | } 70 | return static_cast(size); 71 | } 72 | 73 | } // namespace file 74 | } // namespace envlogger 75 | -------------------------------------------------------------------------------- /envlogger/platform/default/filesystem.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind 
Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_ 17 | 18 | // Intentionally empty. 19 | 20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_ 21 | -------------------------------------------------------------------------------- /envlogger/platform/default/parse_text_proto.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_ 17 | 18 | #include "glog/logging.h" 19 | #include "google/protobuf/text_format.h" 20 | 21 | namespace envlogger { 22 | // Forward declarations for the friend statement. 23 | namespace internal { 24 | class ParseProtoHelper; 25 | } // namespace internal 26 | internal::ParseProtoHelper ParseTextProtoOrDie(const std::string& input); 27 | 28 | namespace internal { 29 | // Helper class to automatically infer the type of the protocol buffer message. 30 | class ParseProtoHelper { 31 | public: 32 | template 33 | operator T() { 34 | T result; 35 | CHECK(google::protobuf::TextFormat::ParseFromString(input_, &result)); 36 | return result; 37 | } 38 | 39 | private: 40 | friend ParseProtoHelper envlogger::ParseTextProtoOrDie( 41 | const std::string& input); 42 | ParseProtoHelper(const std::string& input) : input_(input) {} 43 | 44 | const std::string& input_; 45 | }; 46 | } // namespace internal 47 | 48 | // Parses the specified ASCII protocol buffer message or dies. 49 | internal::ParseProtoHelper ParseTextProtoOrDie(const std::string& input) { 50 | return internal::ParseProtoHelper(input); 51 | } 52 | 53 | // Parses the specified ASCII protocol buffer message and returns it or dies. 54 | template 55 | T ParseTextOrDie(const std::string& input) { 56 | return ParseTextProtoOrDie(input); 57 | } 58 | 59 | } // namespace envlogger 60 | 61 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_ 62 | -------------------------------------------------------------------------------- /envlogger/platform/default/proto_testutil.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_ 17 | 18 | #include 19 | 20 | #include "glog/logging.h" 21 | #include "google/protobuf/text_format.h" 22 | #include "google/protobuf/util/message_differencer.h" 23 | #include "gmock/gmock.h" 24 | 25 | namespace envlogger { 26 | namespace internal { 27 | 28 | // Simple implementation of a proto matcher comparing string representations. 29 | // 30 | // IMPORTANT: Only use this for protos whose textual representation is 31 | // deterministic (that may not be the case for the map collection type). 
32 | class ProtoStringMatcher { 33 | public: 34 | explicit ProtoStringMatcher(const std::string& expected) 35 | : expected_proto_str_(expected) {} 36 | explicit ProtoStringMatcher(const google::protobuf::Message& expected) 37 | : expected_proto_str_(expected.DebugString()) {} 38 | 39 | template 40 | bool MatchAndExplain(const Message& actual_proto, 41 | ::testing::MatchResultListener* listener) const; 42 | 43 | void DescribeTo(::std::ostream* os) const { *os << expected_proto_str_; } 44 | void DescribeNegationTo(::std::ostream* os) const { 45 | *os << "not equal to expected message: " << expected_proto_str_; 46 | } 47 | 48 | void SetComparePartially() { 49 | scope_ = ::google::protobuf::util::MessageDifferencer::PARTIAL; 50 | } 51 | 52 | private: 53 | const std::string expected_proto_str_; 54 | google::protobuf::util::MessageDifferencer::Scope scope_ = 55 | google::protobuf::util::MessageDifferencer::FULL; 56 | }; 57 | 58 | template 59 | T CreateProto(const std::string& textual_proto) { 60 | T proto; 61 | CHECK(google::protobuf::TextFormat::ParseFromString(textual_proto, &proto)); 62 | return proto; 63 | } 64 | 65 | template 66 | bool ProtoStringMatcher::MatchAndExplain( 67 | const Message& actual_proto, 68 | ::testing::MatchResultListener* listener) const { 69 | Message expected_proto = CreateProto(expected_proto_str_); 70 | 71 | google::protobuf::util::MessageDifferencer differencer; 72 | std::string differences; 73 | differencer.ReportDifferencesToString(&differences); 74 | differencer.set_scope(scope_); 75 | 76 | if (!differencer.Compare(expected_proto, actual_proto)) { 77 | *listener << "the protos are different:\n" << differences; 78 | return false; 79 | } 80 | 81 | return true; 82 | } 83 | } // namespace internal 84 | 85 | // Polymorphic matcher to compare any two protos. 
86 | inline ::testing::PolymorphicMatcher EqualsProto( 87 | const std::string& x) { 88 | return ::testing::MakePolymorphicMatcher(internal::ProtoStringMatcher(x)); 89 | } 90 | 91 | // Polymorphic matcher to compare any two protos. 92 | inline ::testing::PolymorphicMatcher EqualsProto( 93 | const google::protobuf::Message& x) { 94 | return ::testing::MakePolymorphicMatcher(internal::ProtoStringMatcher(x)); 95 | } 96 | 97 | // Only compare the fields populated in the matcher proto. 98 | template 99 | inline InnerProtoMatcher Partially(InnerProtoMatcher inner_proto_matcher) { 100 | inner_proto_matcher.mutable_impl().SetComparePartially(); 101 | return inner_proto_matcher; 102 | } 103 | 104 | } // namespace envlogger 105 | 106 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_ 107 | -------------------------------------------------------------------------------- /envlogger/platform/default/riegeli_file_reader.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_ 17 | 18 | #include "riegeli/bytes/fd_reader.h" 19 | 20 | namespace envlogger { 21 | 22 | using RiegeliFileReader = ::riegeli::FdReader<>; 23 | 24 | } // namespace envlogger 25 | 26 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_ 27 | -------------------------------------------------------------------------------- /envlogger/platform/default/riegeli_file_writer.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_ 17 | 18 | #include "riegeli/bytes/fd_writer.h" 19 | 20 | namespace envlogger { 21 | 22 | using RiegeliFileWriter = ::riegeli::FdWriter<>; 23 | 24 | } // namespace envlogger 25 | 26 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_ 27 | -------------------------------------------------------------------------------- /envlogger/platform/default/source_location.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_ 17 | 18 | #include 19 | 20 | namespace envlogger { 21 | namespace internal { 22 | 23 | // Class representing a specific location in the source code of a program. 24 | // source_location is copyable. 25 | class source_location { 26 | public: 27 | // Avoid this constructor; it populates the object with dummy values. 28 | constexpr source_location() : line_(0), file_name_(nullptr) {} 29 | 30 | // Wrapper to invoke the private constructor below. This should only be 31 | // used by the LOC macro, hence the name. 32 | static constexpr source_location DoNotInvokeDirectly(std::uint_least32_t line, 33 | const char* file_name) { 34 | return source_location(line, file_name); 35 | } 36 | 37 | // The line number of the captured source location. 38 | constexpr std::uint_least32_t line() const { return line_; } 39 | 40 | // The file name of the captured source location. 41 | constexpr const char* file_name() const { return file_name_; } 42 | 43 | // column() and function_name() are omitted because we don't have a 44 | // way to support them. 45 | 46 | private: 47 | // Do not invoke this constructor directly. Instead, use the 48 | // LOC macro below. 
49 | // 50 | // file_name must outlive all copies of the source_location 51 | // object, so in practice it should be a string literal. 52 | constexpr source_location(std::uint_least32_t line, const char* file_name) 53 | : line_(line), file_name_(file_name) {} 54 | 55 | std::uint_least32_t line_; 56 | const char* file_name_; 57 | }; 58 | 59 | } // namespace internal 60 | } // namespace envlogger 61 | 62 | // If a function takes a source_location parameter, pass this as the argument. 63 | #define LOC \ 64 | ::envlogger::internal::source_location::DoNotInvokeDirectly(__LINE__, \ 65 | __FILE__) 66 | 67 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_ 68 | -------------------------------------------------------------------------------- /envlogger/platform/default/status_builder.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "envlogger/platform/default/status_builder.h" 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace envlogger { 22 | namespace internal { 23 | 24 | StatusBuilder::StatusBuilder(const StatusBuilder& sb) { 25 | status_ = sb.status_; 26 | file_ = sb.file_; 27 | line_ = sb.line_; 28 | no_logging_ = sb.no_logging_; 29 | stream_ = std::make_unique(sb.stream_->str()); 30 | join_style_ = sb.join_style_; 31 | } 32 | 33 | StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { 34 | status_ = sb.status_; 35 | file_ = sb.file_; 36 | line_ = sb.line_; 37 | no_logging_ = sb.no_logging_; 38 | stream_ = std::make_unique(sb.stream_->str()); 39 | join_style_ = sb.join_style_; 40 | return *this; 41 | } 42 | 43 | StatusBuilder& StatusBuilder::SetAppend() { 44 | if (status_.ok()) return *this; 45 | join_style_ = MessageJoinStyle::kAppend; 46 | return *this; 47 | } 48 | 49 | StatusBuilder& StatusBuilder::SetPrepend() { 50 | if (status_.ok()) return *this; 51 | join_style_ = MessageJoinStyle::kPrepend; 52 | return *this; 53 | } 54 | 55 | StatusBuilder& StatusBuilder::SetNoLogging() { 56 | no_logging_ = true; 57 | return *this; 58 | } 59 | 60 | StatusBuilder::operator absl::Status() const& { 61 | if (stream_->str().empty() || no_logging_) { 62 | return status_; 63 | } 64 | return StatusBuilder(*this).JoinMessageToStatus(); 65 | } 66 | 67 | StatusBuilder::operator absl::Status() && { 68 | if (stream_->str().empty() || no_logging_) { 69 | return status_; 70 | } 71 | return JoinMessageToStatus(); 72 | } 73 | 74 | absl::Status StatusBuilder::JoinMessageToStatus() { 75 | std::string message; 76 | if (join_style_ == MessageJoinStyle::kAnnotate) { 77 | if (!status_.ok()) { 78 | message = absl::StrCat(status_.message(), "; ", stream_->str()); 79 | } 80 | } else { 81 | message = join_style_ == MessageJoinStyle::kPrepend 82 | ? 
absl::StrCat(stream_->str(), status_.message()) 83 | : absl::StrCat(status_.message(), stream_->str()); 84 | } 85 | return absl::Status(status_.code(), message); 86 | } 87 | 88 | } // namespace internal 89 | } // namespace envlogger 90 | -------------------------------------------------------------------------------- /envlogger/platform/default/status_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_ 17 | 18 | #include 19 | 20 | #include "absl/base/attributes.h" 21 | #include "absl/status/status.h" 22 | #include "absl/strings/str_cat.h" 23 | #include "absl/strings/string_view.h" 24 | #include "envlogger/platform/default/source_location.h" 25 | 26 | namespace envlogger { 27 | namespace internal { 28 | 29 | class ABSL_MUST_USE_RESULT StatusBuilder { 30 | public: 31 | StatusBuilder(const StatusBuilder& sb); 32 | StatusBuilder& operator=(const StatusBuilder& sb); 33 | // Creates a `StatusBuilder` based on an original status. If logging is 34 | // enabled, it will use `location` as the location from which the log message 35 | // occurs. 
36 | StatusBuilder(const absl::Status& original_status, source_location location) 37 | : status_(original_status), 38 | line_(location.line()), 39 | file_(location.file_name()), 40 | stream_(new std::ostringstream) {} 41 | 42 | StatusBuilder(absl::Status&& original_status, source_location location) 43 | : status_(std::move(original_status)), 44 | line_(location.line()), 45 | file_(location.file_name()), 46 | stream_(new std::ostringstream) {} 47 | 48 | // Creates a `StatusBuilder` from a status code. If logging is enabled, it 49 | // will use `location` as the location from which the log message occurs. 50 | StatusBuilder(absl::StatusCode code, source_location location) 51 | : status_(code, ""), 52 | line_(location.line()), 53 | file_(location.file_name()), 54 | stream_(new std::ostringstream) {} 55 | 56 | StatusBuilder(const absl::Status& original_status, const char* file, int line) 57 | : status_(original_status), 58 | line_(line), 59 | file_(file), 60 | stream_(new std::ostringstream) {} 61 | 62 | bool ok() const { return status_.ok(); } 63 | 64 | StatusBuilder& SetAppend(); 65 | 66 | StatusBuilder& SetPrepend(); 67 | 68 | StatusBuilder& SetNoLogging(); 69 | 70 | template 71 | StatusBuilder& operator<<(const T& msg) { 72 | if (status_.ok()) return *this; 73 | *stream_ << msg; 74 | return *this; 75 | } 76 | 77 | operator absl::Status() const&; 78 | operator absl::Status() &&; 79 | 80 | absl::Status JoinMessageToStatus(); 81 | 82 | private: 83 | // Specifies how to join the error message in the original status and any 84 | // additional message that has been streamed into the builder. 85 | enum class MessageJoinStyle { 86 | kAnnotate, 87 | kAppend, 88 | kPrepend, 89 | }; 90 | 91 | // The status that the result will be based on. 92 | absl::Status status_; 93 | // The line to record if this file is logged. 94 | int line_; 95 | // Not-owned: The file to record if this status is logged. 
96 | const char* file_; 97 | bool no_logging_ = false; 98 | // The additional messages added with `<<`. 99 | std::unique_ptr stream_; 100 | // Specifies how to join the message in `status_` and `stream_`. 101 | MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate; 102 | }; 103 | 104 | inline StatusBuilder AlreadyExistsErrorBuilder(source_location location) { 105 | return StatusBuilder(absl::StatusCode::kAlreadyExists, location); 106 | } 107 | 108 | inline StatusBuilder FailedPreconditionErrorBuilder(source_location location) { 109 | return StatusBuilder(absl::StatusCode::kFailedPrecondition, location); 110 | } 111 | 112 | inline StatusBuilder InternalErrorBuilder(source_location location) { 113 | return StatusBuilder(absl::StatusCode::kInternal, location); 114 | } 115 | 116 | inline StatusBuilder InvalidArgumentErrorBuilder(source_location location) { 117 | return StatusBuilder(absl::StatusCode::kInvalidArgument, location); 118 | } 119 | 120 | inline StatusBuilder NotFoundErrorBuilder(source_location location) { 121 | return StatusBuilder(absl::StatusCode::kNotFound, location); 122 | } 123 | 124 | inline StatusBuilder UnavailableErrorBuilder(source_location location) { 125 | return StatusBuilder(absl::StatusCode::kUnavailable, location); 126 | } 127 | 128 | inline StatusBuilder UnimplementedErrorBuilder(source_location location) { 129 | return StatusBuilder(absl::StatusCode::kUnimplemented, location); 130 | } 131 | 132 | inline StatusBuilder UnknownErrorBuilder(source_location location) { 133 | return StatusBuilder(absl::StatusCode::kUnknown, location); 134 | } 135 | 136 | } // namespace internal 137 | } // namespace envlogger 138 | 139 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_ 140 | -------------------------------------------------------------------------------- /envlogger/platform/filesystem.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "absl/status/status.h" 23 | #include "absl/status/statusor.h" 24 | #include "absl/strings/string_view.h" 25 | #include "envlogger/platform/default/filesystem.h" 26 | 27 | namespace envlogger { 28 | namespace file { 29 | 30 | // Join multiple paths together. 31 | std::string JoinPath(absl::string_view dirname, absl::string_view basename); 32 | 33 | // Creates a directory with the given path. 34 | // Fails if the directory cannot be created (e.g. it already exists). 35 | absl::Status CreateDir(absl::string_view path); 36 | 37 | // Returns the list of subdirectories under the specified path. If the sentinel 38 | // is not empty, then only subdirectories that contain a file with that name 39 | // will be present. 40 | absl::StatusOr> GetSubdirectories( 41 | absl::string_view path, absl::string_view sentinel = ""); 42 | 43 | // Recursively deletes the specified path. 44 | absl::Status RecursivelyDelete(absl::string_view path); 45 | 46 | // Returns the file size in bytes. 
47 | absl::StatusOr GetSize(absl::string_view path); 48 | 49 | } // namespace file 50 | } // namespace envlogger 51 | 52 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_ 53 | -------------------------------------------------------------------------------- /envlogger/platform/parse_text_proto.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_ 17 | 18 | #include "envlogger/platform/default/parse_text_proto.h" 19 | 20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_ 21 | -------------------------------------------------------------------------------- /envlogger/platform/proto_testutil.h: -------------------------------------------------------------------------------- 1 | // Copyright 2024 DeepMind Technologies Limited.. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_
#define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_

// Forwards to the implementation under envlogger/platform/default/.
#include "envlogger/platform/default/proto_testutil.h"

#endif  // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_
--------------------------------------------------------------------------------
/envlogger/platform/riegeli_file_reader.h:
--------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_
#define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_

// Forwards to the implementation under envlogger/platform/default/.
#include "envlogger/platform/default/riegeli_file_reader.h"

#endif  // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_
--------------------------------------------------------------------------------
/envlogger/platform/riegeli_file_writer.h:
--------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_
#define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_

// Forwards to the implementation under envlogger/platform/default/.
#include "envlogger/platform/default/riegeli_file_writer.h"

#endif  // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_
--------------------------------------------------------------------------------
/envlogger/platform/status_macros.h:
--------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_
#define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_

// Forwards to the implementation under envlogger/platform/default/.
#include "envlogger/platform/default/status_macros.h"

#endif  // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_
--------------------------------------------------------------------------------
/envlogger/platform/test_macros.h:
--------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_ 16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_ 17 | 18 | #include "gtest/gtest.h" 19 | #include "envlogger/platform/status_macros.h" 20 | 21 | #define CONCAT_IMPL(x, y) x##y 22 | #define CONCAT_MACRO(x, y) CONCAT_IMPL(x, y) 23 | 24 | #define ENVLOGGER_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ 25 | ENVLOGGER_ASSERT_OK_AND_ASSIGN_IMPL(CONCAT_MACRO(_status_or, __COUNTER__), \ 26 | lhs, rexpr) 27 | 28 | #define ENVLOGGER_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ 29 | auto statusor = (rexpr); \ 30 | ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ 31 | lhs = std::move(statusor.value()) 32 | 33 | #define ENVLOGGER_EXPECT_OK(expr) ENVLOGGER_CHECK_OK(expr) 34 | 35 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_ 36 | -------------------------------------------------------------------------------- /envlogger/proto/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Protocol buffers for environment logger. 
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
load("@rules_proto//proto:defs.bzl", "proto_library")

package(default_visibility = ["//visibility:public"])

# Language-agnostic description of storage.proto.
proto_library(
    name = "storage_proto",
    srcs = ["storage.proto"],
)

# C++ bindings generated from :storage_proto.
# NOTE(review): cc_proto_library is used as a native rule (there is no load()
# for it); newer Bazel releases may require loading it explicitly -- confirm
# for the Bazel version this workspace targets.
cc_proto_library(
    name = "storage_cc_proto",
    deps = [":storage_proto"],
)

# Python bindings, built by protobuf's legacy macro directly from the source
# file rather than from :storage_proto.
py_proto_library(
    name = "storage_py_pb2",
    srcs = ["storage.proto"],
)
--------------------------------------------------------------------------------
/envlogger/proto/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/envlogger/proto/storage.proto:
--------------------------------------------------------------------------------
// Copyright 2024 DeepMind Technologies Limited..
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package envlogger;

// Datum is basically a multidimensional Matrix (or Tensor if you like).
message Datum {
  // Shape of the tensor held by this Datum.
  message Shape {
    // One dimension of the tensor.
    message Dim {
      // Size of the tensor in that dimension.
      //
      // -438 is a dimension size reserved for pure scalars.
      // Pure scalars are different from zero-dimensional matrices in many
      // programming languages, so this value is used to differentiate between
      // them.
      // It comes from the sum of negative ASCII codes:
      //   S = -83
      //   C = -67
      //   A = -65
      //   L = -76
      //   A = -65
      //   R = -82
      //  ------
      //   -438
      int64 size = 1;
    }

    // Dimensions of the tensor, e.g. two entries with sizes 30 and 40 for a
    // 30 x 40 2D tensor. (Dim carries only a size; dimensions are unnamed.)
    //
    // The order of entries in "dim" matters: It indicates the layout of the
    // values in the tensor in-memory representation.
    //
    // The first entry in "dim" is the outermost dimension used to layout the
    // values, the last entry is the innermost dimension. This matches the
    // in-memory layout of RowMajor Eigen tensors and xt::xarray.
    repeated Dim dim = 2;
  }

  Shape shape = 1;

  // Storage of basic value types.
  // Please refer to the API of your programming language to determine the exact
  // type that's obtained from decoding Datums.
  message Values {
    // Types supported natively by protobuf.
    // 32-bit float.
    repeated float float_values = 1;
    // float32 values encoded as big-endian bytes.
    bytes float_values_buffer = 15;
    // 64-bit float.
    repeated double double_values = 2;
    // 32-bit signed integer.
    repeated int32 int32_values = 3;
    // 64-bit signed integer.
    repeated int64 int64_values = 4;
    // 32-bit unsigned integer.
    repeated uint32 uint32_values = 5;
    // 64-bit unsigned integer.
    repeated uint64 uint64_values = 6;
    // boolean value.
    repeated bool bool_values = 7;
    // string value.
    repeated string string_values = 8;

    // Opaque bytes to store arbitrary user-defined values.
    repeated bytes bytes_values = 9;

    // Arbitrarily long signed ints stored in big-endian byte order.
    // These correspond to normal Python int()s.
    repeated bytes bigint_values = 10;

    // Small ints are not supported by protobuf so we encode them as bytes in
    // big-endian byte order with fixed-length (1 byte for int8 and 2 bytes for
    // int16).
    // https://developers.google.com/protocol-buffers/docs/proto3#scalar

    // int8 corresponds to an np.int8 in Python and int8 in C++.
    bytes int8_values = 11;
    // int16 corresponds to an np.int16 in Python and int16 in C++.
    bytes int16_values = 12;
    // uint8 corresponds to an np.uint8 in Python and uint8 in C++.
    bytes uint8_values = 13;
    // uint16 corresponds to an np.uint16 in Python and uint16 in C++.
    bytes uint16_values = 14;
  }
  Values values = 2;
}

// Recursive container: a single Datum, or an Array, Tuple or Dict of Data.
message Data {
  // Array represents a sequence of homogeneous elements.
  //
  // Our APIs expect and check that all elements in Array are of the same type.
  // An Array corresponds to a Python list() and to a C++ std::vector.
  message Array {
    repeated Data values = 1;
  }

  // Tuple represents a sequence of heterogeneous elements.
  //
  // A Tuple corresponds to a Python tuple() and to a C++ std::tuple<>.
  message Tuple {
    repeated Data values = 1;
  }

  // Dict represents a mapping from elements to elements.
  //
  // A Dict corresponds to a Python dict() and to a C++
  // std::{unordered_}map.
  message Dict {
    // Specialization for string keys.
    map<string, Data> values = 1;

    // A list of key-value pairs.
    // This field is used for arbitrary key types.
    // NOTE: This is currently only used by the Python codec. C++ support will
    // be added in a future change and this note will be removed.
    Array kvs = 2;
  }

  // Notice that the format does NOT differentiate between empty values, that
  // is, an empty Array is represented equally as an empty Dict or empty Tuple.
  // Empty `Data` objects are completely empty and do not occupy any space on
  // disk and are not sent over the wire.
  oneof value {
    Datum datum = 1;
    Array array = 2;
    Tuple tuple = 3;
    Dict dict = 4;
  }
}
--------------------------------------------------------------------------------
/envlogger/reader.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | """Reader of EnvironmentLogger data.""" 17 | 18 | from collections.abc import Sequence 19 | import copy 20 | from typing import Any, Optional, Union 21 | 22 | from absl import logging 23 | from envlogger import step_data 24 | from envlogger.backends import backend_reader 25 | from envlogger.backends import backend_type 26 | from envlogger.backends import in_memory_backend 27 | from envlogger.backends import riegeli_backend_reader 28 | from envlogger.converters import spec_codec 29 | 30 | 31 | class Reader: 32 | """Reader of trajectories generated by EnvLogger.""" 33 | 34 | def __init__(self, 35 | *backend_args, 36 | backend: Union[ 37 | backend_reader.BackendReader, 38 | backend_type.BackendType] = backend_type.BackendType.RIEGELI, 39 | **backend_kwargs): 40 | logging.info('backend: %r', backend) 41 | logging.info('backend_args: %r', backend_args) 42 | logging.info('backend_kwargs: %r', backend_kwargs) 43 | # Set backend. 44 | if isinstance(backend, backend_reader.BackendReader): 45 | self._backend = backend 46 | elif isinstance(backend, backend_type.BackendType): 47 | self._backend = { 48 | backend_type.BackendType.RIEGELI: 49 | riegeli_backend_reader.RiegeliBackendReader, 50 | backend_type.BackendType.IN_MEMORY: 51 | in_memory_backend.InMemoryBackendReader, 52 | }[backend](*backend_args, **backend_kwargs) 53 | else: 54 | raise TypeError(f'Unsupported backend: {backend}') 55 | self._set_specs() 56 | 57 | def copy(self): 58 | c = copy.copy(self) 59 | 60 | c._backend = self._backend.copy() 61 | c._observation = self._observation_spec 62 | c._action_spec = self._action_spec 63 | c._reward_spec = self._reward_spec 64 | c._discount_spec = self._discount_spec 65 | 66 | return c 67 | 68 | def close(self): 69 | self._backend.close() 70 | 71 | def __enter__(self): 72 | return self 73 | 74 | def __exit__(self, exc_type, exc_value, tb): 75 | self.close() 76 | 77 | def __del__(self): 78 | self.close() 79 | 80 | def metadata(self): 81 | return 
self._backend.metadata() 82 | 83 | @property 84 | def steps(self) -> Sequence[step_data.StepData]: 85 | return self._backend.steps 86 | 87 | @property 88 | def episodes(self) -> Sequence[Sequence[step_data.StepData]]: 89 | return self._backend.episodes 90 | 91 | def episode_metadata(self) -> Sequence[Optional[Any]]: 92 | return self._backend.episode_metadata() 93 | 94 | def observation_spec(self): 95 | return self._observation_spec 96 | 97 | def action_spec(self): 98 | return self._action_spec 99 | 100 | def reward_spec(self): 101 | return self._reward_spec 102 | 103 | def discount_spec(self): 104 | return self._discount_spec 105 | 106 | def _set_specs(self) -> None: 107 | """Extracts and decodes environment specs from the logged data.""" 108 | metadata = self._backend.metadata() or {} 109 | env_specs = spec_codec.decode_environment_specs( 110 | metadata.get('environment_specs', {})) 111 | self._observation_spec = env_specs['observation_spec'] 112 | self._action_spec = env_specs['action_spec'] 113 | self._reward_spec = env_specs['reward_spec'] 114 | self._discount_spec = env_specs['discount_spec'] 115 | -------------------------------------------------------------------------------- /envlogger/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | dm_env 3 | numpy 4 | protobuf 5 | tensorflow 6 | tfds-nightly 7 | -------------------------------------------------------------------------------- /envlogger/setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Install script for setuptools.""" 17 | 18 | import os 19 | import posixpath 20 | import shutil 21 | 22 | import pkg_resources 23 | import setuptools 24 | from setuptools.command import build_ext 25 | from setuptools.command import build_py 26 | 27 | PROJECT_NAME = 'envlogger' 28 | 29 | __version__ = '1.2' 30 | 31 | _ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 32 | 33 | _ENVLOGGER_PROTOS = ( 34 | 'proto/storage.proto', 35 | ) 36 | 37 | 38 | class _GenerateProtoFiles(setuptools.Command): 39 | """Command to generate protobuf bindings for EnvLogger protos.""" 40 | 41 | descriptions = 'Generates Python protobuf bindings for EnvLogger protos.' 42 | user_options = [] 43 | 44 | def initialize_options(self): 45 | pass 46 | 47 | def finalize_options(self): 48 | pass 49 | 50 | def run(self): 51 | # We have to import grpc_tools here, after setuptools has installed 52 | # setup_requires dependencies. 
53 | from grpc_tools import protoc 54 | 55 | grpc_protos_include = pkg_resources.resource_filename( 56 | 'grpc_tools', '_proto') 57 | 58 | for proto_path in _ENVLOGGER_PROTOS: 59 | proto_args = [ 60 | 'grpc_tools.protoc', 61 | '--proto_path={}'.format(grpc_protos_include), 62 | '--proto_path={}'.format(_ROOT_DIR), 63 | '--python_out={}'.format(_ROOT_DIR), 64 | '--grpc_python_out={}'.format(_ROOT_DIR), 65 | os.path.join(_ROOT_DIR, proto_path), 66 | ] 67 | if protoc.main(proto_args) != 0: 68 | raise RuntimeError('ERROR: {}'.format(proto_args)) 69 | 70 | 71 | class _BuildPy(build_py.build_py): 72 | """Generate protobuf bindings in build_py stage.""" 73 | 74 | def run(self): 75 | self.run_command('generate_protos') 76 | build_py.build_py.run(self) 77 | 78 | 79 | class BazelExtension(setuptools.Extension): 80 | """A C/C++ extension that is defined as a Bazel BUILD target.""" 81 | 82 | def __init__(self, bazel_target): 83 | self.bazel_target = bazel_target 84 | self.relpath, self.target_name = ( 85 | posixpath.relpath(bazel_target, '//').split(':')) 86 | ext_name = os.path.join( 87 | self.relpath.replace(posixpath.sep, os.path.sep), self.target_name) 88 | super().__init__(ext_name, sources=[]) 89 | 90 | 91 | class _BuildExt(build_ext.build_ext): 92 | """A command that runs Bazel to build a C/C++ extension.""" 93 | 94 | def run(self): 95 | self.run_command('generate_protos') 96 | self.bazel_build() 97 | build_ext.build_ext.run(self) 98 | 99 | def bazel_build(self): 100 | 101 | if not os.path.exists(self.build_temp): 102 | os.makedirs(self.build_temp) 103 | 104 | bazel_argv = [ 105 | 'bazel', 106 | 'build', 107 | '...', 108 | '--symlink_prefix=' + os.path.join(self.build_temp, 'bazel-'), 109 | '--compilation_mode=' + ('dbg' if self.debug else 'opt'), 110 | '--verbose_failures', 111 | ] 112 | 113 | self.spawn(bazel_argv) 114 | 115 | for ext in self.extensions: 116 | ext_bazel_bin_path = os.path.join( 117 | self.build_temp, 'bazel-bin', 118 | ext.relpath, ext.target_name + 
'.so') 119 | 120 | ext_name = ext.name 121 | ext_dest_path = self.get_ext_fullpath(ext_name) 122 | ext_dest_dir = os.path.dirname(ext_dest_path) 123 | if not os.path.exists(ext_dest_dir): 124 | os.makedirs(ext_dest_dir) 125 | shutil.copyfile(ext_bazel_bin_path, ext_dest_path) 126 | 127 | # Copy things from /external to their own libs 128 | # E.g. /external/some_repo/some_lib --> /some_lib 129 | if ext_name.startswith('external/'): 130 | split_path = ext_name.split('/') 131 | ext_name = '/'.join(split_path[2:]) 132 | ext_dest_path = self.get_ext_fullpath(ext_name) 133 | ext_dest_dir = os.path.dirname(ext_dest_path) 134 | if not os.path.exists(ext_dest_dir): 135 | os.makedirs(ext_dest_dir) 136 | shutil.copyfile(ext_bazel_bin_path, ext_dest_path) 137 | 138 | 139 | setuptools.setup( 140 | name=PROJECT_NAME, 141 | version=__version__, 142 | description='EnvLogger: A tool for recording trajectories.', 143 | author='DeepMind', 144 | license='Apache 2.0', 145 | ext_modules=[ 146 | BazelExtension('//envlogger/backends/python:episode_info'), 147 | BazelExtension('//envlogger/backends/python:riegeli_dataset_reader'), 148 | BazelExtension('//envlogger/backends/python:riegeli_dataset_writer'), 149 | ], 150 | cmdclass={ 151 | 'build_ext': _BuildExt, 152 | 'build_py': _BuildPy, 153 | 'generate_protos': _GenerateProtoFiles, 154 | }, 155 | packages=setuptools.find_packages(), 156 | setup_requires=[ 157 | # Some software packages have problems with older versions already 158 | # installed by pip. In particular DeepMind Acme uses grpcio-tools 1.45.0 159 | # (as of 2022-04-20) so we use the same version here. 
160 | 'grpcio-tools>=1.45.0', 161 | ], 162 | install_requires=[ 163 | 'absl-py', 164 | 'dm_env', 165 | 'numpy', 166 | 'protobuf>=3.14', 167 | 'setuptools!=50.0.0', # https://github.com/pypa/setuptools/issues/2350 168 | ], 169 | extras_require={ 170 | 'tfds': [ 171 | 'tensorflow', 172 | 'tfds-nightly', 173 | ], 174 | }) 175 | -------------------------------------------------------------------------------- /envlogger/step_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data type that's written to and returned from storage. 17 | """ 18 | 19 | from typing import Any, NamedTuple 20 | 21 | import dm_env 22 | 23 | 24 | class StepData(NamedTuple): 25 | """Payload that's written at every dm_env.Environment.step() call. 26 | 27 | `StepData` contains the data that's written to logs (i.e. to disk somewhere). 28 | 29 | Attributes: 30 | timestep: The dm_env.TimeStep generated by the environment. 31 | action: The action that led generated `timestep`. 32 | custom_data: Any client-specific data to be written along-side `timestep` 33 | and `action`. It must be supported by converters/codec.py. 
34 | """ 35 | timestep: dm_env.TimeStep 36 | action: Any 37 | custom_data: Any = None 38 | -------------------------------------------------------------------------------- /envlogger/testing/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Testing. 16 | load("@rules_python//python:defs.bzl", "py_library") 17 | 18 | package(default_visibility = ["//visibility:public"]) 19 | 20 | py_library( 21 | name = "catch_env", 22 | srcs = ["catch_env.py"], 23 | ) 24 | -------------------------------------------------------------------------------- /envlogger/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /envlogger/testing/catch_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 DeepMind Technologies Limited.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Catch reinforcement learning environment.""" 17 | 18 | import dm_env 19 | from dm_env import specs 20 | import numpy as np 21 | 22 | _ACTIONS = (-1, 0, 1) # Left, no-op, right. 23 | 24 | 25 | class Catch(dm_env.Environment): 26 | """A Catch environment built on the `dm_env.Environment` class. 27 | 28 | The agent must move a paddle to intercept falling balls. Falling balls only 29 | move downwards on the column they are in. 30 | 31 | The observation is an array shape (rows, columns), with binary values: 32 | zero if a space is empty; 1 if it contains the paddle or a ball. 33 | 34 | The actions are discrete, and by default there are three available: 35 | stay, move left, and move right. 36 | 37 | The episode terminates when the ball reaches the bottom of the screen. 38 | """ 39 | 40 | def __init__(self, rows=10, columns=5, seed=1): 41 | """Initializes a new Catch environment. 42 | 43 | Args: 44 | rows: number of rows. 45 | columns: number of columns. 46 | seed: random seed for the RNG. 
47 | """ 48 | self._rows = rows 49 | self._columns = columns 50 | self._rng = np.random.RandomState(seed) 51 | self._board = np.zeros((rows, columns), dtype=np.float32) 52 | self._ball_x = None 53 | self._ball_y = None 54 | self._paddle_x = None 55 | self._paddle_y = self._rows - 1 56 | self._reset_next_step = True 57 | 58 | def reset(self): 59 | """Returns the first `TimeStep` of a new episode.""" 60 | self._reset_next_step = False 61 | self._ball_x = self._rng.randint(self._columns) 62 | self._ball_y = 0 63 | self._paddle_x = self._columns // 2 64 | return dm_env.restart(self._observation()) 65 | 66 | def step(self, action): 67 | """Updates the environment according to the action.""" 68 | if self._reset_next_step: 69 | return self.reset() 70 | 71 | # Move the paddle. 72 | dx = _ACTIONS[action] 73 | self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1) 74 | 75 | # Drop the ball. 76 | self._ball_y += 1 77 | 78 | # Check for termination. 79 | if self._ball_y == self._paddle_y: 80 | reward = 1. if self._paddle_x == self._ball_x else -1. 81 | self._reset_next_step = True 82 | return dm_env.termination(reward=reward, observation=self._observation()) 83 | else: 84 | return dm_env.transition(reward=0., observation=self._observation()) 85 | 86 | def observation_spec(self): 87 | """Returns the observation spec.""" 88 | return specs.BoundedArray(shape=self._board.shape, dtype=self._board.dtype, 89 | name="board", minimum=0, maximum=1) 90 | 91 | def action_spec(self): 92 | """Returns the action spec.""" 93 | return specs.DiscreteArray( 94 | dtype=int, num_values=len(_ACTIONS), name="action") 95 | 96 | def _observation(self): 97 | self._board.fill(0.) 98 | self._board[self._ball_y, self._ball_x] = 1. 99 | self._board[self._paddle_y, self._paddle_x] = 1. 
100 | return self._board.copy() 101 | -------------------------------------------------------------------------------- /patches/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | licenses(["notice"]) # Apache 2.0 16 | -------------------------------------------------------------------------------- /patches/crc32.BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
19 | crc32c_arm64_HDRS = [
20 |     "src/crc32c_arm64.h",
21 | ]
22 |
23 | crc32c_arm64_SRCS = [
24 |     "src/crc32c_arm64.cc",
25 | ]
26 |
27 | crc32c_sse42_HDRS = [
28 |     "src/crc32c_sse42.h",
29 | ]
30 |
31 | crc32c_sse42_SRCS = [
32 |     "src/crc32c_sse42.cc",
33 | ]
34 |
35 | crc32c_HDRS = [
36 |     "src/crc32c_arm64.h",
37 |     "src/crc32c_arm64_linux_check.h",
38 |     "src/crc32c_internal.h",
39 |     "src/crc32c_prefetch.h",
40 |     "src/crc32c_read_le.h",
41 |     "src/crc32c_round_up.h",
42 |     "src/crc32c_sse42.h",
43 |     "src/crc32c_sse42_check.h",
44 |     "include/crc32c/crc32c.h",
45 | ]
46 |
47 | crc32c_SRCS = [
48 |     "src/crc32c_portable.cc",
49 |     "src/crc32c.cc",
50 | ]
51 |
52 | config_setting(
53 |     name = "windows",
54 |     values = {"cpu": "x64_windows"},
55 |     visibility = ["//visibility:public"],
56 | )
57 |
58 | config_setting(
59 |     name = "linux_x86_64",
60 |     values = {"cpu": "k8"},
61 |     visibility = ["//visibility:public"],
62 | )
63 |
64 | config_setting(
65 |     name = "darwin",
66 |     values = {"cpu": "darwin"},
67 |     visibility = ["//visibility:public"],
68 | )
69 |
70 | sse42_copts = select({
71 |     ":windows": ["/arch:AVX"],
72 |     ":linux_x86_64": ["-msse4.2"],
73 |     ":darwin": ["-msse4.2"],
74 |     "//conditions:default": [],
75 | })
76 |
77 | sse42_enabled = select({
78 |     ":windows": "1",
79 |     ":linux_x86_64": "1",
80 |     ":darwin": "1",
81 |     "//conditions:default": "0",
82 | })
83 |
84 | genrule(
85 |     name = "generate_config",
86 |     srcs = ["src/crc32c_config.h.in"],
87 |     outs = ["crc32c/crc32c_config.h"],
88 |     # HAVE_SSE42 comes from `sse42_enabled` so that the SSE4.2 code path is
89 |     # only enabled on configurations that are also built with `sse42_copts`.
90 |     cmd = """
91 | sed -e 's/#cmakedefine01/#define/' \
92 |     -e 's/ BYTE_ORDER_BIG_ENDIAN/ BYTE_ORDER_BIG_ENDIAN 0/' \
93 |     -e 's/ HAVE_BUILTIN_PREFETCH/ HAVE_BUILTIN_PREFETCH 0/' \
94 |     -e 's/ HAVE_MM_PREFETCH/ HAVE_MM_PREFETCH 0/' \
95 |     -e 's/ HAVE_SSE42/ HAVE_SSE42 """ + sse42_enabled + """/' \
96 |     -e 's/ HAVE_ARM64_CRC32C/ HAVE_ARM64_CRC32C 0/' \
97 |     -e 's/ HAVE_STRONG_GETAUXVAL/ HAVE_STRONG_GETAUXVAL 0/' \
98 |     -e 's/ HAVE_WEAK_GETAUXVAL/ HAVE_WEAK_GETAUXVAL 0/' \
99 |     -e 's/ CRC32C_TESTS_BUILT_WITH_GLOG/ CRC32C_TESTS_BUILT_WITH_GLOG 0/' \
100 |     < $< > $@
101 | """,
102 | )
103 |
104 | cc_library(
105 |     name = "crc32c",
106 |     srcs = crc32c_SRCS + crc32c_sse42_SRCS + crc32c_arm64_SRCS,
107 |     hdrs = crc32c_HDRS + ["crc32c/crc32c_config.h"],
108 |     deps = [],
109 |     includes = ["include"],
110 |     copts = sse42_copts,
111 | )
112 |
--------------------------------------------------------------------------------
/patches/gmp.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the GMP library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
19 | cc_library(
20 |     name = "gmp",
21 |     hdrs = ["gmpxx.h"],
22 |     includes = ["."],
23 |     # libgmpxx depends on libgmp, so it must come first on the link line.
24 |     linkopts = [
25 |         "-lgmpxx",
26 |         "-lgmp",
27 |     ],
28 |     visibility = ["//visibility:public"],
29 | )
30 |
--------------------------------------------------------------------------------
/patches/net_zstd.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | package(default_visibility = ["//visibility:public"]) 16 | 17 | licenses(["notice"]) 18 | 19 | # Build target for the zstd compression library, compiled from the C 20 | # sources under common/, compress/ and decompress/. Only zstd.h is exported 21 | # as a public header. 22 | # NOTE(review): only .c/.h files are globbed; if the pinned zstd release 23 | # ships assembly sources (e.g. decompress/*.S) they are not built - confirm. 24 | cc_library( 25 | name = "zstdlib", 26 | srcs = glob([ 27 | "common/*.c", 28 | "common/*.h", 29 | "compress/*.c", 30 | "compress/*.h", 31 | "decompress/*.c", 32 | "decompress/*.h", 33 | ]), 34 | hdrs = ["zstd.h"], 35 | ) 36 | -------------------------------------------------------------------------------- /patches/proto_utils.cc.diff: -------------------------------------------------------------------------------- 1 | --- pybind11_protobuf/proto_utils.cc 2 | +++ pybind11_protobuf/proto_utils.cc 3 | @@ -13,7 +13,9 @@ 4 | #include "google/protobuf/any.pb.h" 5 | #include "google/protobuf/descriptor.pb.h" 6 | #include "google/protobuf/descriptor.h" 7 | +#include "google/protobuf/io/zero_copy_stream_impl.h" 8 | #include "google/protobuf/message.h" 9 | + 10 | #include "absl/strings/str_format.h" 11 | 12 | void pybind11_proto_casters_collision() { 13 | @@ -1024,8 +1026,8 @@ 14 | any_proto.value()); 15 | } else { 16 | bytes serialized(nullptr, any_proto.value().size()); 17 | - absl::SNPrintF(PYBIND11_BYTES_AS_STRING(serialized.ptr()), 18 | - any_proto.value().size(), 
any_proto.value().c_str()); 19 | + any_proto.value().copy(PYBIND11_BYTES_AS_STRING(serialized.ptr()), 20 | + any_proto.value().size()); 21 | getattr(py_proto, "ParseFromString")(serialized); 22 | return true; 23 | } 24 | -------------------------------------------------------------------------------- /patches/riegeli.diff:
-------------------------------------------------------------------------------- 1 | --- 2 | python/riegeli/records/BUILD | 4 +++- 3 | 1 file changed, 3 insertions(+), 1 deletion(-) 4 | 5 | diff --git a/python/riegeli/records/BUILD b/python/riegeli/records/BUILD 6 | index cde6d546..3daa9dad 100644 7 | --- a/python/riegeli/records/BUILD 8 | +++ b/python/riegeli/records/BUILD 9 | @@ -90,5 +90,7 @@ py_library( 10 | py_proto_library( 11 | name = "records_metadata_py_pb2", 12 | srcs = ["records_metadata.proto"], 13 | - deps = ["@com_google_protobuf//:protobuf_python"], 14 | + deps = [ 15 | + "@com_google_protobuf//:well_known_types_py_pb2", 16 | + ], 17 | ) 18 | -- 19 | 2.25.1 20 | 21 | -------------------------------------------------------------------------------- /patches/snappy.BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited.. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
19 | cc_library(
20 |     name = "snappy",
21 |     srcs = [
22 |         "config.h",
23 |         "snappy.cc",
24 |         "snappy-internal.h",
25 |         "snappy-sinksource.cc",
26 |         "snappy-stubs-internal.cc",
27 |         "snappy-stubs-internal.h",
28 |         "snappy-stubs-public.h",
29 |     ],
30 |     hdrs = ["snappy-sinksource.h", "snappy.h"],
31 |     copts = ["-DHAVE_CONFIG_H", "-Wno-sign-compare"],
32 | )
33 |
34 | # Generates config.h. Availability of compiler builtins and system headers is
35 | # probed by the preprocessor at compile time (__has_builtin / __has_include).
36 | genrule(
37 |     name = "config_h",
38 |     outs = ["config.h"],
39 |     cmd = "\n".join([
40 |         "cat <<'EOF' >$@",
41 |         "#define HAVE_STDDEF_H 1",
42 |         "#define HAVE_STDINT_H 1",
43 |         "",
44 |         "#ifdef __has_builtin",
45 |         "# if !defined(HAVE_BUILTIN_EXPECT) && __has_builtin(__builtin_expect)",
46 |         "# define HAVE_BUILTIN_EXPECT 1",
47 |         "# endif",
48 |         "# if !defined(HAVE_BUILTIN_CTZ) && __has_builtin(__builtin_ctzll)",
49 |         "# define HAVE_BUILTIN_CTZ 1",
50 |         "# endif",
51 |         "#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)",
52 |         "# ifndef HAVE_BUILTIN_EXPECT",
53 |         "# define HAVE_BUILTIN_EXPECT 1",
54 |         "# endif",
55 |         "# ifndef HAVE_BUILTIN_CTZ",
56 |         "# define HAVE_BUILTIN_CTZ 1",
57 |         "# endif",
58 |         "#endif",
59 |         "",
60 |         "#ifdef __has_include",
61 |         "# if !defined(HAVE_BYTESWAP_H) && __has_include(<byteswap.h>)",
62 |         "# define HAVE_BYTESWAP_H 1",
63 |         "# endif",
64 |         "# if !defined(HAVE_UNISTD_H) && __has_include(<unistd.h>)",
65 |         "# define HAVE_UNISTD_H 1",
66 |         "# endif",
67 |         "# if !defined(HAVE_SYS_ENDIAN_H) && __has_include(<sys/endian.h>)",
68 |         "# define HAVE_SYS_ENDIAN_H 1",
69 |         "# endif",
70 |         "# if !defined(HAVE_SYS_MMAN_H) && __has_include(<sys/mman.h>)",
71 |         "# define HAVE_SYS_MMAN_H 1",
72 |         "# endif",
73 |         "# if !defined(HAVE_SYS_UIO_H) && __has_include(<sys/uio.h>)",
74 |         "# define HAVE_SYS_UIO_H 1",
75 |         "# endif",
76 |         "#endif",
77 |         "",
78 |         "#ifndef SNAPPY_IS_BIG_ENDIAN",
79 |         "# ifdef __s390x__",
80 |         "# define SNAPPY_IS_BIG_ENDIAN 1",
81 |         "# elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__",
82 |         "# define SNAPPY_IS_BIG_ENDIAN 1",
83 |         "# endif",
84 |         "#endif",
85 |         "EOF",
86 |     ]),
87 | )
88 |
89 | genrule(
90 |     name = "snappy_stubs_public_h",
91 |     srcs = ["snappy-stubs-public.h.in"],
92 |     outs = ["snappy-stubs-public.h"],
93 |     cmd = ("sed " +
94 |            "-e 's/$${\\(.*\\)_01}/\\1/g' " +
95 |            "-e 's/$${SNAPPY_MAJOR}/1/g' " +
96 |            "-e 's/$${SNAPPY_MINOR}/1/g' " +
97 |            "-e 's/$${SNAPPY_PATCHLEVEL}/4/g' " +
98 |            "$< >$@"),
99 | )
100 |
--------------------------------------------------------------------------------
/patches/xtensor.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the xtensor library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
19 | filegroup(
20 |     name = "all",
21 |     srcs = glob(["include/**"]),
22 | )
23 |
24 | cc_library(
25 |     name = "xtensor",
26 |     hdrs = [":all"],
27 |     strip_include_prefix = "include/",
28 |     deps = ["@xtl//:xtl"],
29 |     visibility = ["//visibility:public"],
30 |     defines = ["XTENSOR_GLIBCXX_USE_CXX11_ABI"],
31 | )
32 |
--------------------------------------------------------------------------------
/patches/xtl.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the xtl library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
19 | # Every file under include/ is exported as a header.
20 | filegroup(
21 |     name = "all",
22 |     srcs = glob(["include/**"]),
23 | )
24 |
25 | # Header-only library (no srcs); `strip_include_prefix` lets dependents
26 | # include these headers by their path relative to include/.
27 | cc_library(
28 |     name = "xtl",
29 |     hdrs = [":all"],
30 |     strip_include_prefix = "include/",
31 |     visibility = ["//visibility:public"],
32 | )
33 |
--------------------------------------------------------------------------------
/patches/zlib.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
19 | # Build target for the zlib compression library, compiled from its C sources.
20 | cc_library(
21 |     name = "zlib",
22 |     srcs = [
23 |         "adler32.c",
24 |         "compress.c",
25 |         "crc32.c",
26 |         "crc32.h",
27 |         "deflate.c",
28 |         "deflate.h",
29 |         "gzclose.c",
30 |         "gzguts.h",
31 |         "gzlib.c",
32 |         "gzread.c",
33 |         "gzwrite.c",
34 |         "infback.c",
35 |         "inffast.c",
36 |         "inffast.h",
37 |         "inffixed.h",
38 |         "inflate.c",
39 |         "inflate.h",
40 |         "inftrees.c",
41 |         "inftrees.h",
42 |         "trees.c",
43 |         "trees.h",
44 |         "uncompr.c",
45 |         "zutil.c",
46 |         "zutil.h",
47 |     ],
48 |     hdrs = [
49 |         "zconf.h",
50 |         "zlib.h",
51 |     ],
52 |     # NOTE(review): presumably needed because some zlib sources trigger
53 |     # implicit-function-declaration diagnostics on certain toolchains; confirm.
54 |     copts = ["-Wno-implicit-function-declaration"],
55 |     includes = ["."],
56 | )
57 |
--------------------------------------------------------------------------------