├── .bazelrc
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── WORKSPACE.bazel
├── docker
├── Dockerfile
└── build.sh
├── docs
└── images
│ ├── envlogger.png
│ └── timings.png
├── envlogger
├── BUILD.bazel
├── __init__.py
├── backends
│ ├── BUILD.bazel
│ ├── __init__.py
│ ├── backend_reader.py
│ ├── backend_type.py
│ ├── backend_writer.py
│ ├── cc
│ │ ├── BUILD.bazel
│ │ ├── __init__.py
│ │ ├── episode_info.h
│ │ ├── riegeli_dataset_io_constants.h
│ │ ├── riegeli_dataset_io_test.cc
│ │ ├── riegeli_dataset_reader.cc
│ │ ├── riegeli_dataset_reader.h
│ │ ├── riegeli_dataset_reader_test.cc
│ │ ├── riegeli_dataset_writer.cc
│ │ ├── riegeli_dataset_writer.h
│ │ ├── riegeli_dataset_writer_test.cc
│ │ ├── riegeli_shard_io_test.cc
│ │ ├── riegeli_shard_reader.cc
│ │ ├── riegeli_shard_reader.h
│ │ ├── riegeli_shard_reader_test.cc
│ │ ├── riegeli_shard_writer.cc
│ │ ├── riegeli_shard_writer.h
│ │ └── riegeli_shard_writer_test.cc
│ ├── cross_language_test
│ │ ├── BUILD.bazel
│ │ ├── cc_reader.cc
│ │ ├── cross_language_test.py
│ │ └── py_writer.py
│ ├── in_memory_backend.py
│ ├── in_memory_backend_test.py
│ ├── python
│ │ ├── BUILD.bazel
│ │ ├── __init__.py
│ │ ├── episode_info.cc
│ │ ├── episode_info_test.py
│ │ ├── riegeli_dataset_reader.cc
│ │ ├── riegeli_dataset_test.py
│ │ └── riegeli_dataset_writer.cc
│ ├── riegeli_backend_reader.py
│ ├── riegeli_backend_writer.py
│ ├── riegeli_backend_writer_test.py
│ ├── rlds_utils.py
│ ├── rlds_utils_test.py
│ ├── schedulers.py
│ ├── schedulers_test.py
│ ├── tfds_backend_testlib.py
│ ├── tfds_backend_writer.py
│ └── tfds_backend_writer_test.py
├── converters
│ ├── BUILD.bazel
│ ├── __init__.py
│ ├── codec.py
│ ├── codec_test.py
│ ├── make_visitor.h
│ ├── make_visitor_test.cc
│ ├── spec_codec.py
│ ├── spec_codec_test.py
│ ├── xtensor_codec.cc
│ ├── xtensor_codec.h
│ └── xtensor_codec_test.cc
├── environment_logger.py
├── environment_logger_test.py
├── environment_wrapper.py
├── examples
│ ├── BUILD.bazel
│ ├── __init__.py
│ ├── random_agent_catch.py
│ └── tfds_random_agent_catch.py
├── platform
│ ├── BUILD.bazel
│ ├── bundle.h
│ ├── default
│ │ ├── BUILD.bazel
│ │ ├── bundle.cc
│ │ ├── bundle.h
│ │ ├── filesystem.cc
│ │ ├── filesystem.h
│ │ ├── parse_text_proto.h
│ │ ├── proto_testutil.h
│ │ ├── riegeli_file_reader.h
│ │ ├── riegeli_file_writer.h
│ │ ├── source_location.h
│ │ ├── status_builder.cc
│ │ ├── status_builder.h
│ │ └── status_macros.h
│ ├── filesystem.h
│ ├── parse_text_proto.h
│ ├── proto_testutil.h
│ ├── riegeli_file_reader.h
│ ├── riegeli_file_writer.h
│ ├── status_macros.h
│ └── test_macros.h
├── proto
│ ├── BUILD.bazel
│ ├── __init__.py
│ └── storage.proto
├── reader.py
├── requirements.txt
├── setup.py
├── step_data.py
└── testing
│ ├── BUILD.bazel
│ ├── __init__.py
│ └── catch_env.py
└── patches
├── BUILD.bazel
├── crc32.BUILD.bazel
├── gmp.BUILD.bazel
├── highwayhash.BUILD.bazel
├── net_zstd.BUILD.bazel
├── proto_utils.cc.diff
├── riegeli.diff
├── snappy.BUILD.bazel
├── xtensor.BUILD.bazel
├── xtl.BUILD.bazel
└── zlib.BUILD.bazel
/.bazelrc:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # The `-std=c++17` option is required to compile the Riegeli library and
16 | # `-D_GLIBCXX_USE_CXX11_ABI=0` option is for compatibility with the shared
17 | # libraries compiled with the same setting.
18 |
19 | build -c opt
20 | build --cxxopt="-mavx"
21 | build --cxxopt="-std=c++17"
22 | build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 |
3 | # ARG variable for Python 3 minor version.
4 | # This can be set with `--build-arg PY3_VERSION=10` for Python 3.10.
5 | ARG PY3_VERSION=10
6 |
7 | # Set up timezone to avoid getting stuck at `tzdata` setup.
8 | ENV TZ=America/Montreal
9 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
10 | RUN apt update
11 | RUN apt install -y tzdata
12 |
13 | # Install `software-properties-common` to obtain add-apt-repository.
14 | RUN apt update && apt install -y software-properties-common
15 |
16 | # Add the deadsnakes PPA to get the requested Python 3 version.
17 | RUN add-apt-repository ppa:deadsnakes/ppa
18 |
19 | # Install necessary packages.
20 | RUN apt-get update && apt-get install -y git curl wget software-properties-common python3.$PY3_VERSION python3.$PY3_VERSION-dev libgmp-dev gcc-9 g++-9 tmux vim
21 |
22 | # Install distutils if not Python 3.10 to get `distutils.util`.
23 | RUN if [ "$PY3_VERSION" != "10" ]; then apt install -y python3-distutils python3-apt; fi
24 |
25 | # Install pip.
26 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.$PY3_VERSION
27 |
28 | # Download bazel.
29 | RUN wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-linux-x86_64
30 | RUN chmod +x /bazel-6.2.1-linux-x86_64
31 | RUN mv /bazel-6.2.1-linux-x86_64 /usr/bin/bazel
32 |
33 | # Add python alternatives.
34 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.$PY3_VERSION 1
35 |
36 | # Override gcc/g++.
37 | RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 60 --slave /usr/bin/g++ g++ /usr/bin/g++-9
38 |
39 | # Install some basic things for all python versions.
40 | RUN echo 1 | update-alternatives --config python3
41 | # Get latest `pip`.
42 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3
43 | RUN python3 -m pip install --no-cache --upgrade pip setuptools
44 |
45 | ADD /envlogger/requirements.txt /tmp/requirements.txt
46 | RUN python3 -m pip install --no-cache -r /tmp/requirements.txt grpcio-tools
47 |
48 |
49 | # Add `python` so that `/usr/bin/env` finds it. This is used by `bazel`.
50 | RUN ln -s /usr/bin/python3 /usr/bin/python
51 |
52 | ADD . /envlogger/
53 | WORKDIR /envlogger
54 |
--------------------------------------------------------------------------------
/docker/build.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2024 DeepMind Technologies Limited..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Builds the docker image described by docker/Dockerfile.
#
# Usage: build.sh [PY3_MINOR_VERSION] [IMAGE_LABEL]
#   PY3_MINOR_VERSION: forwarded to the Dockerfile as the PY3_VERSION build
#     argument (default: 11, i.e. Python 3.11).
#   IMAGE_LABEL: tag given to the resulting image (default: envlogger).
#
# NOTE: the shebang must be the very first line of the file to take effect,
# which is why it precedes the license header.

# Abort on the first failing command.
set -e

# Resolve the directory containing this script and move to its parent, so the
# relative `docker/Dockerfile` path and the `.` build context below resolve
# regardless of where the script is invoked from. The path is quoted so that
# directories containing spaces work.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
cd "${DIR}/.."

# Default to Python 3.11.
PY3_VERSION=${1:-11}
echo Python version 3.${PY3_VERSION}

# Default image label to "envlogger"
IMAGE_LABEL=${2:-envlogger}
echo Output docker image label: ${IMAGE_LABEL}

docker build -t "${IMAGE_LABEL}" -f docker/Dockerfile . --build-arg PY3_VERSION="${PY3_VERSION}"
29 |
--------------------------------------------------------------------------------
/docs/images/envlogger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/envlogger/db88178769bf7e6a01e2ec52ff2b816db39adb2e/docs/images/envlogger.png
--------------------------------------------------------------------------------
/docs/images/timings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/envlogger/db88178769bf7e6a01e2ec52ff2b816db39adb2e/docs/images/timings.png
--------------------------------------------------------------------------------
/envlogger/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# Build targets for environment logger.
# `py_test` is loaded alongside `py_library` because this file declares a
# `py_test` target (`environment_logger_test`).
load("@rules_python//python:defs.bzl", "py_library", "py_test")
17 |
18 | package(default_visibility = ["//visibility:public"])
19 |
20 | py_library(
21 | name = "envlogger",
22 | srcs = ["__init__.py"],
23 | deps = [
24 | ":environment_logger",
25 | ":reader",
26 | ":step_data",
27 | "//envlogger/backends:backend_type",
28 | "//envlogger/backends:riegeli_backend_writer",
29 | "//envlogger/backends:schedulers",
30 | "//envlogger/proto:storage_py_pb2",
31 | ],
32 | )
33 |
34 | py_library(
35 | name = "environment_logger",
36 | srcs = ["environment_logger.py"],
37 | deps = [
38 | ":environment_wrapper",
39 | ":step_data",
40 | "//envlogger/backends:backend_type",
41 | "//envlogger/backends:backend_writer",
42 | "//envlogger/backends:in_memory_backend",
43 | "//envlogger/backends:riegeli_backend_writer",
44 | "//envlogger/converters:spec_codec",
45 | ],
46 | )
47 |
48 | py_test(
49 | name = "environment_logger_test",
50 | srcs = ["environment_logger_test.py"],
51 | deps = [
52 | ":environment_logger",
53 | ":reader",
54 | ":step_data",
55 | "//envlogger/backends:backend_writer",
56 | "//envlogger/backends:in_memory_backend",
57 | "//envlogger/backends:schedulers",
58 | "//envlogger/converters:codec",
59 | "//envlogger/converters:spec_codec",
60 | "//envlogger/proto:storage_py_pb2",
61 | "//envlogger/testing:catch_env",
62 | "@com_google_riegeli//python/riegeli",
63 | ],
64 | )
65 |
66 | py_library(
67 | name = "environment_wrapper",
68 | srcs = ["environment_wrapper.py"],
69 | deps = [
70 | ],
71 | )
72 |
73 | py_library(
74 | name = "reader",
75 | srcs = ["reader.py"],
76 | data = [
77 | "//envlogger/backends/python:episode_info.so",
78 | "//envlogger/backends/python:riegeli_dataset_reader.so",
79 | ],
80 | deps = [
81 | ":step_data",
82 | "//envlogger/backends:backend_reader",
83 | "//envlogger/backends:backend_type",
84 | "//envlogger/backends:in_memory_backend",
85 | "//envlogger/backends:riegeli_backend_reader",
86 | "//envlogger/converters:spec_codec",
87 | ],
88 | )
89 |
90 | py_library(
91 | name = "step_data",
92 | srcs = ["step_data.py"],
93 | deps = [
94 | ],
95 | )
96 |
--------------------------------------------------------------------------------
/envlogger/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A one-stop import for commonly used modules in EnvLogger."""
17 |
18 | from envlogger import environment_logger
19 | from envlogger import reader
20 | from envlogger import step_data
21 | from envlogger.backends import backend_type
22 | from envlogger.backends import riegeli_backend_writer
23 | from envlogger.backends import schedulers
24 | from envlogger.proto import storage_pb2
25 |
26 | EnvLogger = environment_logger.EnvLogger
27 | Reader = reader.Reader
28 | BackendType = backend_type.BackendType
29 | StepData = step_data.StepData
30 | Scheduler = schedulers.Scheduler
31 | RiegeliBackendWriter = riegeli_backend_writer.RiegeliBackendWriter
32 | Data = storage_pb2.Data
33 | Datum = storage_pb2.Datum
34 |
--------------------------------------------------------------------------------
/envlogger/backends/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Backends.
16 | load("@rules_python//python:defs.bzl", "py_library", "py_test")
17 |
18 | package(default_visibility = ["//visibility:public"])
19 |
20 | py_library(
21 | name = "backend_type",
22 | srcs = ["backend_type.py"],
23 | )
24 |
25 | py_library(
26 | name = "schedulers",
27 | srcs = ["schedulers.py"],
28 | deps = [
29 | "//envlogger:step_data",
30 | ],
31 | )
32 |
33 | py_test(
34 | name = "schedulers_test",
35 | srcs = ["schedulers_test.py"],
36 | deps = [
37 | ":schedulers",
38 | "//envlogger:step_data",
39 | ],
40 | )
41 |
42 | py_library(
43 | name = "backend_reader",
44 | srcs = ["backend_reader.py"],
45 | data = ["//envlogger/backends/python:episode_info.so"],
46 | deps = [
47 | "//envlogger:step_data",
48 | "//envlogger/converters:codec",
49 | ],
50 | )
51 |
py_library(
    name = "backend_writer",
    srcs = ["backend_writer.py"],
    deps = [
        # backend_writer.py imports `envlogger.backends.schedulers`.
        ":schedulers",
        "//envlogger:step_data",
    ],
)
59 |
60 | py_library(
61 | name = "in_memory_backend",
62 | srcs = ["in_memory_backend.py"],
63 | data = ["//envlogger/backends/python:episode_info.so"],
64 | deps = [
65 | ":backend_reader",
66 | ":backend_writer",
67 | "//envlogger:step_data",
68 | ],
69 | )
70 |
71 | py_test(
72 | name = "in_memory_backend_test",
73 | srcs = ["in_memory_backend_test.py"],
74 | deps = [
75 | ":backend_writer",
76 | ":in_memory_backend",
77 | ":schedulers",
78 | "//envlogger:step_data",
79 | "//envlogger/testing:catch_env",
80 | ],
81 | )
82 |
83 | py_library(
84 | name = "riegeli_backend_reader",
85 | srcs = ["riegeli_backend_reader.py"],
86 | data = [
87 | "//envlogger/backends/python:episode_info.so",
88 | "//envlogger/backends/python:riegeli_dataset_reader.so",
89 | ],
90 | deps = [
91 | ":backend_reader",
92 | "//envlogger:step_data",
93 | "//envlogger/converters:codec",
94 | "//envlogger/proto:storage_py_pb2",
95 | ],
96 | )
97 |
98 | py_library(
99 | name = "riegeli_backend_writer",
100 | srcs = ["riegeli_backend_writer.py"],
101 | data = [
102 | "//envlogger/backends/python:riegeli_dataset_writer.so",
103 | ],
104 | deps = [
105 | ":backend_writer",
106 | ":schedulers",
107 | "//envlogger:step_data",
108 | "//envlogger/converters:codec",
109 | ],
110 | )
111 |
112 | py_test(
113 | name = "riegeli_backend_writer_test",
114 | srcs = ["riegeli_backend_writer_test.py"],
115 | deps = [
116 | ":backend_writer",
117 | ":riegeli_backend_reader",
118 | ":riegeli_backend_writer",
119 | ":schedulers",
120 | "//envlogger:step_data",
121 | "//envlogger/testing:catch_env",
122 | ],
123 | )
124 |
125 | py_library(
126 | name = "rlds_utils",
127 | srcs = ["rlds_utils.py"],
128 | deps = [
129 | "//envlogger:step_data",
130 | ],
131 | )
132 |
133 | py_library(
134 | name = "tfds_backend_writer",
135 | srcs = ["tfds_backend_writer.py"],
136 | deps = [
137 | ":backend_writer",
138 | ":rlds_utils",
139 | "//envlogger:step_data",
140 | ],
141 | )
142 |
143 | py_library(
144 | name = "tfds_backend_testlib",
145 | srcs = ["tfds_backend_testlib.py"],
146 | deps = [
147 | ":backend_writer",
148 | ":schedulers",
149 | ":tfds_backend_writer",
150 | "//envlogger:step_data",
151 | "//envlogger/testing:catch_env",
152 | ],
153 | )
154 |
155 | py_test(
156 | name = "rlds_utils_test",
157 | srcs = ["rlds_utils_test.py"],
158 | deps = [
159 | ":rlds_utils",
160 | ":tfds_backend_testlib",
161 | "//envlogger:step_data",
162 | ],
163 | )
164 |
165 | py_test(
166 | name = "tfds_backend_writer_test",
167 | srcs = ["tfds_backend_writer_test.py"],
168 | deps = [
169 | ":backend_writer",
170 | ":rlds_utils",
171 | ":tfds_backend_testlib",
172 | ":tfds_backend_writer",
173 | "//envlogger:step_data",
174 | ],
175 | )
176 |
--------------------------------------------------------------------------------
/envlogger/backends/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/backends/backend_reader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract interface for reading trajectories."""
17 |
18 | import abc
19 | from collections.abc import Callable, Iterator, Sequence
20 | from typing import Any, Generic, Optional, TypeVar, Union
21 |
22 | from absl import logging
23 | from envlogger import step_data
24 | from envlogger.backends.python import episode_info
25 | from envlogger.converters import codec
26 |
27 |
T = TypeVar('T')


class _SequenceAdapter(Generic[T], Sequence[T]):
  """Read-only sequence view over items that are fetched on demand.

  Items are never cached by this class: every access calls `get_nth_item`.
  Each call to `__iter__` walks the sequence from the first item using its own
  position, so abandoning a loop midway or iterating the same adapter in
  nested loops does not disturb other iterations. `__next__` keeps a separate
  cursor stored on the instance.
  """

  def __init__(self, count: int, get_nth_item: Callable[[int], T]):
    """Constructor.

    Args:
      count: Total number of items.
      get_nth_item: Function to get the nth item (0-based).
    """
    self._count = count
    # Cursor consumed by `__next__` only; `__iter__` does not touch it.
    self._index = 0
    self._get_nth_item = get_nth_item

  def __getitem__(self, index: Union[int, slice]) -> Union[T, list[T]]:
    """Retrieves items from this sequence.

    Args:
      index: item index or slice of indices.

    Returns:
      The item at `index` if index is of type `int`, or a list of items if
      `index` is a slice. If `index` is a negative integer, then it is
      equivalent to index + len(self).

    Raises:
      IndexError: if index is an integer outside of the bounds [-length,
        length - 1].
    """
    if isinstance(index, slice):
      indices = index.indices(len(self))
      return [self._get_nth_item(i) for i in range(*indices)]
    if index >= self._count or index < -self._count:
      raise IndexError(f'`index`=={index} is out of the range [{-self._count}, '
                       f'{self._count - 1}].')
    index = index if index >= 0 else index + self._count
    return self._get_nth_item(index)

  def __len__(self) -> int:
    """Returns the total number of items."""
    return self._count

  def __iter__(self) -> Iterator[T]:
    """Yields every item in order, always starting from the first one.

    The position is kept in a local variable rather than on the instance, so
    an iteration that is abandoned early cannot leave stale state behind and
    concurrent or nested iterations are independent of each other.
    """
    for position in range(len(self)):
      yield self[position]

  def __next__(self) -> T:
    """Returns the item at the instance cursor and advances the cursor.

    Raises:
      StopIteration: once the cursor has moved past the last item.
    """
    if self._index < len(self):
      index = self._index
      self._index += 1
      return self[index]
    else:
      raise StopIteration()
85 |
86 |
class BackendReader(metaclass=abc.ABCMeta):
  """Abstract base for objects that read back logged trajectories.

  Subclasses provide the storage-specific primitives (item counts, step and
  episode-info lookup, copying, closing and dataset metadata). On top of those
  this class builds the lazily evaluated `steps`, `episodes` and
  `episode_metadata()` sequences. Instances can be used as context managers;
  `close()` is called on exit and again when the object is garbage collected.
  """

  def __init__(self):
    self._init_visitors()

  # --- Lifetime management. ---

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, tb):
    self.close()

  def __del__(self):
    self.close()

  @abc.abstractmethod
  def close(self) -> None:
    pass

  # --- Copying. ---

  def copy(self) -> 'BackendReader':
    """Returns a copy of self with freshly created sequence views."""
    clone = self._copy()
    # The views hold bound methods of the object they were created on, so
    # they are rebuilt on the copy.
    clone._init_visitors()
    return clone

  @abc.abstractmethod
  def _copy(self) -> 'BackendReader':
    """Implementation-specific copy behavior."""

  # --- Primitives implemented by subclasses. ---

  @abc.abstractmethod
  def _get_nth_step(self, i: int) -> step_data.StepData:
    pass

  @abc.abstractmethod
  def _get_num_steps(self) -> int:
    pass

  @abc.abstractmethod
  def _get_num_episodes(self) -> int:
    pass

  @abc.abstractmethod
  def _get_nth_episode_info(self,
                            i: int,
                            include_metadata: bool = False
                           ) -> episode_info.EpisodeInfo:
    pass

  @abc.abstractmethod
  def metadata(self) -> dict[str, Any]:
    pass

  # --- Sequence views built on top of the primitives. ---

  def _init_visitors(self):
    """Creates the step, episode and episode-metadata sequence views."""
    logging.info('Creating visitors.')
    self._steps = _SequenceAdapter(
        count=self._get_num_steps(),
        get_nth_item=self._get_nth_step,
    )
    self._episodes = _SequenceAdapter(
        count=self._get_num_episodes(),
        get_nth_item=self._get_nth_episode,
    )
    self._episode_metadata = _SequenceAdapter(
        count=self._get_num_episodes(),
        get_nth_item=self._get_nth_episode_metadata,
    )
    logging.info('Done creating visitors.')

  def _get_nth_episode(self, i: int) -> Sequence[step_data.StepData]:
    """Returns a lazy sequence of the steps of episode `i` (0-based)."""
    info = self._get_nth_episode_info(i, include_metadata=False)
    # Step `j` of the episode is step `info.start + j` of the step sequence.
    return _SequenceAdapter(
        count=info.num_steps,
        get_nth_item=lambda j: self._get_nth_step(info.start + j),
    )

  def _get_nth_episode_metadata(self, i: int) -> Optional[Any]:
    """Returns the decoded metadata of episode `i` (0-based)."""
    info = self._get_nth_episode_info(i, include_metadata=True)
    return codec.decode(info.metadata)

  # --- Public accessors. ---

  @property
  def steps(self) -> Sequence[step_data.StepData]:
    return self._steps

  @property
  def episodes(self) -> Sequence[Sequence[step_data.StepData]]:
    return self._episodes

  # NOTE: unlike `steps` and `episodes`, this is a method, not a property.
  def episode_metadata(self) -> Sequence[Optional[Any]]:
    return self._episode_metadata
178 |
--------------------------------------------------------------------------------
/envlogger/backends/backend_type.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An enumeration that specifies the logging backend."""
17 |
18 | import enum
19 |
20 |
class BackendType(enum.IntEnum):
  """Selects the logging backend."""

  # The integer values are assigned explicitly.
  RIEGELI = 0
  IN_MEMORY = 1
24 |
--------------------------------------------------------------------------------
/envlogger/backends/backend_writer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract trajectory logging interface."""
17 |
18 | import abc
19 | from typing import Any, Optional
20 |
21 | from envlogger import step_data
22 | from envlogger.backends import schedulers
23 |
24 |
class BackendWriter(metaclass=abc.ABCMeta):
  """Base class for objects that record trajectory steps.

  Subclasses implement `_record_step()`, `set_episode_metadata()` and
  `close()`. This class adds optional step filtering through a scheduler and
  keeps the dataset-level metadata. `close()` is also invoked when the object
  is garbage collected.
  """

  def __init__(
      self,
      metadata: Optional[dict[str, Any]] = None,
      scheduler: Optional[schedulers.Scheduler] = None,
  ):
    """Stores the dataset metadata and the optional scheduler.

    Args:
      metadata: Any dataset-level custom data to be written. It is returned
        unchanged by `metadata()`.
      scheduler: Optional callable that `record_step()` invokes with the step
        data. The step is forwarded to `_record_step()` only if the callable
        returns a truthy value; when no scheduler is given every step is
        forwarded. NOTE: The scheduler should NEVER skip the first timestep of
        an episode, otherwise `EnvLogger` will not know that such episode
        really exists.
    """
    self._metadata = metadata
    self._scheduler = scheduler

  def __del__(self):
    self.close()

  def metadata(self):
    return self._metadata

  def record_step(self, data: step_data.StepData, is_new_episode: bool) -> None:
    """Records `data` unless the scheduler decides to skip it."""
    should_record = self._scheduler is None or self._scheduler(data)
    if should_record:
      self._record_step(data, is_new_episode)

  @abc.abstractmethod
  def _record_step(self, data: step_data.StepData,
                   is_new_episode: bool) -> None:
    pass

  @abc.abstractmethod
  def set_episode_metadata(self, data: Any) -> None:
    pass

  @abc.abstractmethod
  def close(self) -> None:
    pass
71 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/episode_info.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // Struct for efficiently retrieving episodic information.
16 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_EPISODE_INFO_H_
17 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_EPISODE_INFO_H_
18 |
19 | #include
20 | #include
21 |
22 | #include "envlogger/proto/storage.pb.h"
23 |
24 | namespace envlogger {
25 |
26 | struct EpisodeInfo {
27 | // The step index where this episode starts.
28 | // This can be either a local step within a single trajectory file, or a
29 | // global step across many shards.
30 | int64_t start;
31 | // The number of steps in this episode.
32 | int64_t num_steps;
33 | // Optional metadata which is only filled if requested.
34 | std::optional metadata;
35 | };
36 |
37 | } // namespace envlogger
38 |
39 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC__EPISODE_INFO_H_
40 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_dataset_io_constants.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_
17 |
18 | #include "absl/strings/string_view.h"
19 |
20 | namespace envlogger {
21 | // Filenames used by the Riegeli dataset I/O code. They are internal to this
22 | // codebase and must not be relied upon by clients.
23 | namespace internal {
24 | // Dataset-level metadata file; its contents are set at Init() time.
25 | inline constexpr absl::string_view kMetadataFilename = "metadata.riegeli";
26 | // Per-step records (timesteps, actions and per-step metadata).
27 | inline constexpr absl::string_view kStepsFilename = "steps.riegeli";
28 | // Index of step offsets, used for faster seeking into kStepsFilename.
29 | inline constexpr absl::string_view kStepOffsetsFilename =
30 | "step_offsets.riegeli";
31 | // Per-episode metadata records.
32 | inline constexpr absl::string_view kEpisodeMetadataFilename =
33 | "episode_metadata.riegeli";
34 | // Episode index: episode starts, plus offsets for faster seeking into
35 | // kEpisodeMetadataFilename.
36 | inline constexpr absl::string_view kEpisodeIndexFilename =
37 | "episode_index.riegeli";
38 | } // namespace internal
39 | } // namespace envlogger
40 |
41 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_IO_CONSTANTS_H_
42 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_dataset_io_test.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 |
22 | #include "gmock/gmock.h"
23 | #include "gtest/gtest.h"
24 | #include "absl/flags/flag.h"
25 | #include "absl/status/status.h"
26 | #include "absl/strings/string_view.h"
27 | #include "absl/time/clock.h"
28 | #include "absl/time/time.h"
29 | #include "envlogger/backends/cc/episode_info.h"
30 | #include "envlogger/backends/cc/riegeli_dataset_reader.h"
31 | #include "envlogger/backends/cc/riegeli_dataset_writer.h"
32 | #include "envlogger/backends/cc/riegeli_shard_writer.h"
33 | #include "envlogger/platform/filesystem.h"
34 | #include "envlogger/platform/parse_text_proto.h"
35 | #include "envlogger/platform/proto_testutil.h"
36 | #include "envlogger/platform/test_macros.h"
37 | #include "envlogger/proto/storage.pb.h"
38 | #include "riegeli/records/record_reader.h"
39 | #include "riegeli/records/record_writer.h"
40 |
41 | namespace envlogger {
42 | namespace {
43 |
44 | using ::testing::Eq;
45 | using ::testing::Not;
46 | using ::testing::Value;
47 |
48 | // Matcher for episode info: checks `start` and `num_steps` of the argument.
49 | MATCHER_P2(EqualsEpisode, start_index, num_steps, "") {
50 |   return Value(arg.num_steps, num_steps) && Value(arg.start, start_index);
51 | }
52 |
53 | TEST(RiegeliDatasetTest, MetadataTest) {
54 |   // Metadata given to the writer's Init() must be read back unchanged.
55 |   const std::string data_dir =
56 |       file::JoinPath(getenv("TEST_TMPDIR"), "metadata");
57 |   const Data metadata =
58 |       ParseTextProtoOrDie("datum: { values: { int32_values: 1234 } }");
59 |   const int max_episodes_per_shard = -1;
60 |   ENVLOGGER_EXPECT_OK(file::CreateDir(data_dir));
61 |   {
62 |     RiegeliDatasetWriter writer;
63 |     ENVLOGGER_EXPECT_OK(writer.Init(data_dir, metadata, max_episodes_per_shard,
64 |                                     "transpose,brotli:6,chunk_size:1M"));
65 |     // Write a single step to pass RiegeliDatasetReader::Init()'s strict checks.
66 |     Data data;
67 |     data.mutable_datum()->mutable_values()->add_float_values(1.234f);
68 |     EXPECT_TRUE(writer.AddStep(data, /*is_new_episode=*/true));
69 |     writer.Flush();
70 |   }
71 |
72 |   RiegeliDatasetReader reader;
73 |   ENVLOGGER_EXPECT_OK(reader.Init(data_dir));
74 |   const auto actual_metadata = reader.Metadata();
75 |   // Fatal on failure: dereferencing an empty optional below would be UB.
76 |   ASSERT_THAT(actual_metadata, Not(Eq(std::nullopt)));
77 |   EXPECT_THAT(*actual_metadata, EqualsProto(metadata));
78 |   ENVLOGGER_EXPECT_OK(file::RecursivelyDelete(data_dir));
79 | }
80 |
81 | } // namespace
82 | } // namespace envlogger
83 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_dataset_reader.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 |
26 | #include "google/protobuf/message.h"
27 | #include "absl/status/status.h"
28 | #include "absl/status/statusor.h"
29 | #include "absl/strings/string_view.h"
30 | #include "envlogger/backends/cc/episode_info.h"
31 | #include "envlogger/backends/cc/riegeli_shard_reader.h"
32 | #include "envlogger/backends/cc/riegeli_shard_writer.h"
33 | #include "envlogger/proto/storage.pb.h"
34 |
35 | namespace envlogger {
36 |
37 | // A RiegeliDatasetReader contains trajectory information for the entire
38 | // lifetime of a single Reinforcement Learning (RL) actor. Each
39 | // RiegeliDatasetReader represents a one dimensional trajectory composed of 0 or
40 | // more steps. Steps can be grouped into non-overlapping, contiguous episodes
41 | // for episodic RL environments.
42 | //
43 | // Internally, a RiegeliDatasetReader is sharded into chronologically ordered
44 | // sub directories called "timestamp directories". Each of these sub directories
45 | // contains its own trajectories and indexes, so in order to find a specific
46 | // step or episode we first need to determine the sub directory and then the
47 | // index internal to it. An episode is never split between two shards.
48 | class RiegeliDatasetReader {
49 | public:
50 | RiegeliDatasetReader() = default;
51 | RiegeliDatasetReader(RiegeliDatasetReader&&) = default;
52 | ~RiegeliDatasetReader();
53 |
54 | // Clones a RiegeliDatasetReader.
55 | // The cloned reader can be safely used in a different thread.
56 | absl::StatusOr Clone();
57 |
58 | // Releases resources and closes the underlying files.
59 | void Close();
60 |
61 | absl::Status Init(absl::string_view data_dir);  // Initializes from `data_dir`.
62 |
63 | // Returns metadata associated with this RiegeliDatasetReader.
64 | std::optional Metadata() const;
65 |
66 | int64_t NumSteps() const;  // Exclusive upper bound for Step() indices.
67 | int64_t NumEpisodes() const;  // Exclusive upper bound for Episode() indices.
68 |
69 | // Returns step data at index `step_index`.
70 | // Returns nullopt if `step_index` is not in [0, NumSteps()).
71 | template
72 | std::optional Step(int64_t step_index);
73 |
74 | // Returns information for accessing a specific episode.
75 | //
76 | // NOTE: `start` in the returned object refers to the "global" step index, not
77 | // the local one in the shard. That is, this value can be passed to Step()
78 | // above without any modification.
79 | //
80 | // Returns nullopt if `episode_index` is not in [0, NumEpisodes()).
81 | std::optional Episode(int64_t episode_index,
82 | bool include_metadata = false);
83 |
84 | // Returns a shard reader for the specified episode index.
85 | absl::StatusOr GetShard(int64_t episode_index);
86 |
87 | std::string DataDir() const { return data_dir_; }  // Returns a copy.
88 |
89 | private:
90 | // A Shard represents a single timestamp directory.
91 | struct Shard {
92 | // The path to the timestamp directory.
93 | std::string timestamp_dir;
94 | // The index internal to this timestamp directory.
95 | RiegeliShardReader index;
96 |
97 | // The global step index at which this shard starts.
98 | int64_t global_step_index = -1;
99 |
100 | // The cumulative number of steps up to this shard (in the order that it's
101 | // inserted in shards_).
102 | int64_t cumulative_steps = 0;
103 | // The cumulative number of episodes up to this shard (in the order that
104 | // it's inserted in shards_).
105 | int64_t cumulative_episodes = 0;
106 | };
107 |
108 | // Maps `global_index` to its shard and the index local to that shard, using
109 | // `extractor` to read each shard's cumulative count (e.g. cumulative_steps).
110 | std::pair FindShard(
111 | int64_t global_index, std::function extractor);
112 |
113 | // The data directory associated with this reader.
114 | std::string data_dir_;
115 |
116 | // The total number of steps across all timestamp directories.
117 | int64_t total_num_steps_ = 0;
118 | // The total number of episodes across all timestamp directories.
119 | // Note: This implementation differs from the Python one because the latter
120 | // depends on an explicit "max_episodes_per_file" entry in the metadata while
121 | // this one calculates the same thing via binary search.
122 | int64_t total_num_episodes_ = 0;
123 |
124 | std::optional metadata_;  // Value handed out by Metadata() — TODO confirm.
125 |
126 | // The list of all shards in this RiegeliDatasetReader.
127 | std::vector shards_;
128 | };
129 |
130 | ////////////////////////////////////////////////////////////////////////////////
131 | // Implementation details of RiegeliDatasetReader::Step.
132 | ////////////////////////////////////////////////////////////////////////////////
133 |
134 | template <typename T>
135 | std::optional<T> RiegeliDatasetReader::Step(int64_t step_index) {
136 |   // Reject indices outside of [0, NumSteps()).
137 |   if (step_index < 0 || step_index >= NumSteps()) return std::nullopt;
138 |
139 |   // Map the global step index to (shard, index local to that shard).
140 |   const auto [shard, local_step_index] = FindShard(
141 |       step_index, [](const Shard& s) { return s.cumulative_steps; });
142 |   // Defensive: never dereference a missing shard.
143 |   if (shard == nullptr) return std::nullopt;
144 |   return shard->index.Step<T>(local_step_index);
145 | }
146 |
147 | } // namespace envlogger
148 |
149 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_READER_H_
150 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_dataset_writer.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/backends/cc/riegeli_dataset_writer.h"
16 |
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 |
24 | #include "glog/logging.h"
25 | #include "google/protobuf/message.h"
26 | #include "absl/status/status.h"
27 | #include "absl/status/statusor.h"
28 | #include "absl/strings/string_view.h"
29 | #include "absl/time/clock.h"
30 | #include "absl/time/time.h"
31 | #include "envlogger/backends/cc/episode_info.h"
32 | #include "envlogger/backends/cc/riegeli_dataset_io_constants.h"
33 | #include "envlogger/backends/cc/riegeli_shard_reader.h"
34 | #include "envlogger/backends/cc/riegeli_shard_writer.h"
35 | #include "envlogger/platform/filesystem.h"
36 | #include "envlogger/platform/riegeli_file_writer.h"
37 | #include "envlogger/platform/status_macros.h"
38 | #include "envlogger/proto/storage.pb.h"
39 | #include "riegeli/base/types.h"
40 | #include "riegeli/records/record_reader.h"
41 | #include "riegeli/records/record_writer.h"
42 |
43 | namespace envlogger {
44 |
45 | namespace {
46 |
47 | // Stores `data` as a single record in the Riegeli file at `filepath`.
48 | // Returns the writer's status if writing, flushing or closing fails.
49 | absl::Status WriteSingleRiegeliRecord(const absl::string_view filepath,
50 |                                       const Data& data) {
51 |   riegeli::RecordWriter record_writer(
52 |       RiegeliFileWriter(filepath),
53 |       riegeli::RecordWriterBase::Options().set_transpose(true));
54 |   const bool ok = record_writer.WriteRecord(data) &&
55 |                   record_writer.Flush(riegeli::FlushType::kFromMachine) &&
56 |                   record_writer.Close();
57 |   return ok ? absl::OkStatus() : record_writer.status();
58 | }
59 |
60 | std::string NewTimestampDirName(absl::Time time) {  // `time` rendered in UTC.
61 | return absl::FormatTime("%Y%m%dT%H%M%S%E6f", time, absl::UTCTimeZone());
62 | }
63 |
64 | // Points `writer` at a new timestamp directory created under `data_dir`.
65 | absl::Status CreateRiegeliShardWriter(absl::string_view data_dir,
66 |                                       absl::string_view writer_options,
67 |                                       RiegeliShardWriter* writer) {
68 |   const std::string timestamp_dir =
69 |       file::JoinPath(data_dir, NewTimestampDirName(absl::Now()));
70 |   ENVLOGGER_RETURN_IF_ERROR(file::CreateDir(timestamp_dir));
71 |   writer->Flush();  // Flush the writer before re-initializing it.
72 |   const auto in_dir = [&timestamp_dir](absl::string_view filename) {
73 |     return file::JoinPath(timestamp_dir, filename);
74 |   };
75 |   ENVLOGGER_RETURN_IF_ERROR(writer->Init(
76 |       /*steps_filepath=*/in_dir(internal::kStepsFilename),
77 |       /*step_offsets_filepath=*/in_dir(internal::kStepOffsetsFilename),
78 |       /*episode_metadata_filepath=*/in_dir(internal::kEpisodeMetadataFilename),
79 |       /*episode_index_filepath=*/in_dir(internal::kEpisodeIndexFilename),
80 |       writer_options));
81 |   return absl::OkStatus();
82 | }
83 |
84 | } // namespace
85 |
86 | absl::Status RiegeliDatasetWriter::Init(std::string data_dir,
87 |                                         const Data& metadata,
88 |                                         int64_t max_episodes_per_shard,
89 |                                         std::string writer_options,
90 |                                         int episode_counter) {
91 |   if (data_dir.empty()) return absl::NotFoundError("Empty data_dir.");
92 |   data_dir_ = data_dir;
93 |   writer_options_ = writer_options;
94 |
95 |   // Dataset-level metadata is written once: an existing file is left untouched.
96 |   const std::string metadata_path =
97 |       file::JoinPath(data_dir_, internal::kMetadataFilename);
98 |   if (!file::GetSize(metadata_path).ok()) {
99 |     ENVLOGGER_RETURN_IF_ERROR(WriteSingleRiegeliRecord(
100 |         /*filepath=*/metadata_path, /*data=*/metadata));
101 |   }
102 |   // Without a positive per-shard limit a single shard is opened right away;
103 |   // otherwise AddStep() opens shards as new episodes arrive.
104 |   max_episodes_per_shard_ = max_episodes_per_shard;
105 |   if (max_episodes_per_shard_ <= 0) {
106 |     ENVLOGGER_RETURN_IF_ERROR(
107 |         CreateRiegeliShardWriter(data_dir_, writer_options_, &writer_));
108 |   }
109 |   episode_counter_ = episode_counter;
110 |   return absl::OkStatus();
111 | }
111 |
112 | bool RiegeliDatasetWriter::AddStep(const google::protobuf::Message& data,
113 |                                    bool is_new_episode) {
114 |   // The counter is post-incremented only when sharding is enabled.
115 |   const bool new_shard = is_new_episode && max_episodes_per_shard_ > 0 &&
116 |                          episode_counter_++ % max_episodes_per_shard_ == 0;
117 |   if (new_shard) {
118 |     ENVLOGGER_CHECK_OK(
119 |         CreateRiegeliShardWriter(data_dir_, writer_options_, &writer_));
120 |   }
121 |   return writer_.AddStep(data, is_new_episode);
122 | }
123 |
124 | void RiegeliDatasetWriter::SetEpisodeMetadata(const Data& data) {
125 | writer_.SetEpisodeMetadata(data);  // Forwarded as-is to the shard writer.
126 | }
127 |
128 | void RiegeliDatasetWriter::Flush() { writer_.Flush(); }  // Forwards to writer_.
129 |
130 | void RiegeliDatasetWriter::Close() { writer_.Close(); }  // Forwards to writer_.
131 |
132 | } // namespace envlogger
133 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_dataset_writer.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | #include "google/protobuf/message.h"
24 | #include "absl/status/status.h"
25 | #include "absl/strings/string_view.h"
26 | #include "envlogger/backends/cc/episode_info.h"
27 | #include "envlogger/backends/cc/riegeli_shard_reader.h"
28 | #include "envlogger/backends/cc/riegeli_shard_writer.h"
29 | #include "envlogger/proto/storage.pb.h"
30 |
31 | namespace envlogger {
32 |
33 | // Automates creating trajectories that can be efficiently read from disk by
34 | // RiegeliDatasetReader.
35 | class RiegeliDatasetWriter {
36 | public:
37 | RiegeliDatasetWriter() = default;
38 |
39 | // Initializes this writer to the following `data_dir`.
40 | // `metadata` is a client-specific payload, written only if `data_dir` does not
41 | // yet contain a metadata file.
42 | // `max_episodes_per_shard` is the maximum number of episodes a single shard
43 | // holds. If non-positive, one shard holds all steps and episodes.
44 | // `writer_options` is forwarded to the shard writer.
45 | // `episode_counter` sets the current episode that this writer is processing;
46 | // it only influences sharding when `max_episodes_per_shard` is positive.
47 | //
48 | // IMPORTANT: `data_dir` MUST exist _before_ calling Init().
49 | absl::Status Init(
50 | std::string data_dir, const Data& metadata = Data(),
51 | int64_t max_episodes_per_shard = 0,
52 | std::string writer_options = "transpose,brotli:6,chunk_size:1M",
53 | int episode_counter = 0);
54 |
55 | // Adds `data` to the trajectory, marking it as a new episode if
56 | // `is_new_episode==true` (which may first start a new shard).
57 | // Returns true if `data` has been written, false otherwise.
58 | bool AddStep(const google::protobuf::Message& data, bool is_new_episode = false);
59 |
60 | // Sets episodic metadata for the _current_ episode.
61 | // This can be called multiple times but the value will be written only when a
62 | // new episode comes in or when this writer is about to be destructed.
63 | // Notice that calling this before an `AddStep(..., /*is_new_episode=*/true)`
64 | // is called leads to this writer ignoring the `data` that's passed.
65 | void SetEpisodeMetadata(const Data& data);
66 |
67 | void Flush();  // Flushes the current shard writer.
68 |
69 | // Finalizes all writes and releases all handles.
70 | void Close();
71 |
72 | // Const accessors to internal state (each returns a copy).
73 | std::string DataDir() const { return data_dir_; }
74 | std::string WriterOptions() const { return writer_options_; }
75 | int64_t MaxEpisodesPerShard() const { return max_episodes_per_shard_; }
76 | int EpisodeCounter() const { return episode_counter_; }
77 |
78 | private:
79 | std::string data_dir_;  // Directory passed to Init().
80 | std::string writer_options_;  // Options passed to Init().
81 | int64_t max_episodes_per_shard_ = 0;  // Non-positive means a single shard.
82 | int episode_counter_ = 0;  // Advanced by AddStep() only when sharding.
83 | RiegeliShardWriter writer_;  // Writer of the shard currently being filled.
84 | };
85 |
86 | } // namespace envlogger
87 |
88 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_DATASET_WRITER_H_
89 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_shard_reader.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 |
24 | #include "glog/logging.h"
25 | #include "google/protobuf/message.h"
26 | #include "absl/status/status.h"
27 | #include "absl/status/statusor.h"
28 | #include "absl/strings/str_cat.h"
29 | #include "absl/strings/string_view.h"
30 | #include "envlogger/backends/cc/episode_info.h"
31 | #include "envlogger/platform/riegeli_file_reader.h"
32 | #include "envlogger/proto/storage.pb.h"
33 | #include "riegeli/base/object.h"
34 | #include "riegeli/records/record_reader.h"
35 |
36 | namespace envlogger {
37 |
38 | // A RiegeliShardReader refers to a single Riegeli index file (usually named
39 | // "index.riegeli") that's associated with a Riegeli trajectory file (usually
40 | // named "trajectories.riegeli"). The index file is an external index into the
41 | // trajectory so that accesses have minimal latency.
42 | //
43 | // This class reads the data from one of these index files and returns
44 | // information for accessing individual steps and episodes in the trajectories
45 | // file.
46 | // RiegeliShardReader is not thread-safe (i.e. multiple threads should not
47 | // concurrently call `Step()` or `Episode()`). However, it is possible to create
48 | // cheap copies of RiegeliShardReader using the Clone function, creating a new
49 | // reader which will share with the original reader the relatively expensive
50 | // in-memory index but will contain their individual Riegeli file handlers.
51 | // These copies can then be safely passed to threads.
52 | //
53 | // Note on nomenclature:
54 | // - The term "index" when used in an array refers to the position in the
55 | // array and is always 0-based throughout this library. For example, in the
56 | // array [13, 17, 24, 173] the element 13 is at index 0, 17 at index 1, 24 at
57 | // index 2 and 173 at index 3.
58 | // - The term "offset" refers to a file offset in bytes. This is the unit used
59 | // by Riegeli to position its reading head to read a particular record.
60 | class RiegeliShardReader {
61 | public:
62 | RiegeliShardReader() = default;
63 |
64 | RiegeliShardReader(RiegeliShardReader&&) = default;
65 | RiegeliShardReader& operator=(RiegeliShardReader&&) = default;
66 |
67 | ~RiegeliShardReader();
68 |
69 | // Reads trajectory data written by RiegeliShardWriter.
70 | //
71 | // If `step_offsets_filepath` is empty this returns early (the original note
72 | // says "the constructor"; presumably Init() is meant — TODO confirm), leaving
73 | // the object members empty. All methods should fail if called on an object in
74 | // this state. This is used by clients for efficiently initializing objects to
75 | // minimize memory allocations by preallocating everything beforehand.
76 | absl::Status Init(absl::string_view steps_filepath,
77 | absl::string_view step_offsets_filepath,
78 | absl::string_view episode_metadata_filepath,
79 | absl::string_view episode_index_filepath);
80 |
81 | // Clones a RiegeliShardReader.
82 | // The cloned reader owns its own file handles but shares the ShardData
83 | // with the original reader.
84 | // The cloned reader can safely be used in a different thread.
85 | absl::StatusOr Clone();
86 |
87 | // Returns the number of steps indexed by this RiegeliShardReader.
88 | int64_t NumSteps() const;
89 |
90 | // Returns the number of episodes indexed by this RiegeliShardReader.
91 | int64_t NumEpisodes() const;
92 |
93 | // Returns step data at index `step_index`.
94 | // Returns nullopt if step_index is not in [0, NumSteps()).
95 | template
96 | std::optional Step(int64_t step_index);
97 |
98 | // Returns information for accessing a specific episode.
99 | // Returns nullopt if `episode_index` is not in [0, NumEpisodes()).
100 | std::optional Episode(int64_t episode_index,
101 | bool include_metadata = false);
102 |
103 | // Releases resources and closes the underlying files.
104 | void Close();
105 |
106 | private:
107 | struct ShardData {
108 | std::string steps_filepath;
109 | std::string episode_metadata_filepath;
110 |
111 | // Riegeli offsets for quickly accessing steps.
112 | std::vector step_offsets;
113 |
114 | // Riegeli offsets for quickly accessing episodic metadata.
115 | std::vector episode_metadata_offsets;
116 |
117 | // The first step of each episode.
118 | // Episode index -> Step index.
119 | std::vector episode_starts;
120 | };
121 |
122 | // Structure of the shard; null on a default-constructed reader.
123 | // Note that all instances created with Clone() will share that information.
124 | std::shared_ptr shard_;
125 |
126 | riegeli::RecordReader steps_reader_{riegeli::kClosed};  // Used by Step().
127 | riegeli::RecordReader episode_metadata_reader_{
128 | riegeli::kClosed};
129 | };
130 |
131 | ////////////////////////////////////////////////////////////////////////////////
132 | // Implementation details of RiegeliShardReader::Step.
133 | ////////////////////////////////////////////////////////////////////////////////
134 |
135 | template <typename T>
136 | std::optional<T> RiegeliShardReader::Step(int64_t step_index) {
137 |   // A never-initialized reader has a null `shard_`: fail instead of crashing.
138 |   if (shard_ == nullptr) return std::nullopt;
139 |   const auto& step_offsets = shard_->step_offsets;
140 |   if (step_index < 0 ||
141 |       step_index >= static_cast<int64_t>(step_offsets.size())) {
142 |     return std::nullopt;
143 |   }
144 |   // Position the reader at the record's file offset before reading it.
145 |   const int64_t offset = step_offsets[step_index];
146 |   if (!steps_reader_.Seek(offset)) {
147 |     VLOG(0) << absl::StrCat("Failed to seek to offset ", offset,
148 |                             " status: ", steps_reader_.status().ToString());
149 |     return std::nullopt;
150 |   }
151 |   T data;
152 |   if (!steps_reader_.ReadRecord(data)) return std::nullopt;
153 |   return data;
154 | }
155 |
156 | } // namespace envlogger
157 |
158 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_READER_H_
159 |
--------------------------------------------------------------------------------
/envlogger/backends/cc/riegeli_shard_writer.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_
17 |
18 | #include
19 | #include
20 | #include
21 |
22 | #include "google/protobuf/message.h"
23 | #include "absl/status/status.h"
24 | #include "absl/strings/string_view.h"
25 | #include "envlogger/backends/cc/episode_info.h"
26 | #include "envlogger/platform/riegeli_file_writer.h"
27 | #include "envlogger/proto/storage.pb.h"
28 | #include "riegeli/base/object.h"
29 | #include "riegeli/records/record_writer.h"
30 |
31 | namespace envlogger {
32 |
33 | // Creates Riegeli index files to allow for efficient record lookup from another
34 | // Riegeli file (i.e. an external index).
35 | class RiegeliShardWriter {
36 | public:
37 | RiegeliShardWriter() = default;
38 | ~RiegeliShardWriter();
39 |
40 | // IMPORTANT: The directory where these files live MUST exist _before_ calling
41 | // Init().
42 | //
43 | // steps_filepath:
44 | // Path to riegeli file that'll contain the actual trajectory data as Data
45 | // objects, including timesteps (observations, rewards, discounts), actions
46 | // and step metadata.
47 | // Each riegeli entry corresponds to a single step.
48 | // Each entry can be seeked in O(1) time via an external index
49 | // (step_offsets_filepath).
50 | // step_offsets_filepath:
51 | // An external index into steps_filepath. This allows clients to access any
52 | // individual step in O(1). Riegeli entries are Datum objects.
53 | // episode_metadata_filepath:
54 | // Path to riegeli file that'll contain episodic metadata (e.g. discounted
55 | // returns, summarized data etc). Riegeli entries are Data objects.
56 | // Each entry corresponds to the metadata of a single episode.
57 | // Metadata is stored only for episodes that have it (i.e. if an episode
58 | // does not store any metadata, nothing is stored here).
59 | // Each entry can be seeked in O(1) time via an external index
60 | // (episode_index_filepath).
61 | // episode_index_filepath:
62 | // Path to a riegeli file containing (|Episodes| x 2) tensors where:
63 | // * 1st dim is a step index indicating the start of the episode.
64 | // * 2nd dim is a riegeli offset for optional metadata (for looking up into
65 | // episode_metadata_filepath). -1 indicates that no metadata exists for a
66 | // specific episode.
67 | // Each riegeli entry is a Datum object.
68 | absl::Status Init(absl::string_view steps_filepath,
69 | absl::string_view step_offsets_filepath,
70 | absl::string_view episode_metadata_filepath,
71 | absl::string_view episode_index_filepath,
72 | absl::string_view writer_options);
73 |
74 | // Adds `data` to the trajectory and marks it as a new episode if
75 | // `is_new_episode==true`.
76 | // Returns true if `data` has been written, false otherwise.
77 | bool AddStep(const google::protobuf::Message& data, bool is_new_episode = false);
78 |
79 | // Sets episodic metadata for the _current_ episode.
80 | // This can be called multiple times but the value will be written only when a
81 | // new episode comes in or when this writer is about to be destructed.
82 | // Notice that calling this before an `AddStep(..., /*is_new_episode=*/true)`
83 | // is called leads to this writer ignoring the `data` that's passed.
84 | void SetEpisodeMetadata(const Data& data);
85 |
86 | // Flushes the index to disk.
87 | void Flush();
88 |
89 | // Finalizes all writes and releases all handles.
90 | void Close();
91 |
92 | private:
93 | // The number of steps in the last flush.
94 | int num_steps_at_flush_ = 0;
95 |
96 | std::vector step_offsets_;  // Presumably one riegeli offset per step.
97 |
98 | // The first step of each episode.
99 | // Episode index -> Step index.
100 | std::vector episode_starts_;
101 | // The riegeli offset into `episode_metadata_writer_`.
102 | std::vector episode_offsets_;
103 |
104 | // Metadata for the _current_ episode.
105 | std::optional episode_metadata_;
106 |
107 | // Steps, episodes and their riegeli numeric offsets; all start out closed.
108 | riegeli::RecordWriter steps_writer_{riegeli::kClosed};
109 | riegeli::RecordWriter step_offsets_writer_{
110 | riegeli::kClosed};
111 | riegeli::RecordWriter episode_metadata_writer_{
112 | riegeli::kClosed};
113 | riegeli::RecordWriter episode_index_writer_{
114 | riegeli::kClosed};
115 |
116 | };
117 |
118 | } // namespace envlogger
119 |
120 | #endif // THIRD_PARTY_PY_ENVLOGGER_BACKENDS_CC_RIEGELI_SHARD_WRITER_H_
121 |
--------------------------------------------------------------------------------
/envlogger/backends/cross_language_test/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Tests for reading/writing from different languages.
16 |
17 | load("@rules_python//python:defs.bzl", "py_binary", "py_test")
18 |
# C++ binary that reads a trajectory directory back and validates its contents.
# It is run as a subprocess by :cross_language_test.
cc_binary(
    name = "cc_reader",
    testonly = True,
    srcs = ["cc_reader.cc"],
    deps = [
        "//envlogger/backends/cc:riegeli_dataset_reader",
        "//envlogger/converters:xtensor_codec",
        "//envlogger/platform:proto_testutil",
        "//envlogger/proto:storage_cc_proto",
        "@com_github_google_glog//:glog",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/flags:parse",
        "@com_google_googletest//:gtest",
        "@gmp",
        "@xtensor",
    ],
)

# Python binary that writes the trajectory consumed by :cc_reader.
py_binary(
    name = "py_writer",
    srcs = ["py_writer.py"],
    deps = [
        "//envlogger",
    ],
)

# Drives :py_writer and then :cc_reader on the same directory; both binaries are
# shipped as runfiles and located at run time.
py_test(
    name = "cross_language_test",
    srcs = ["cross_language_test.py"],
    data = [
        ":cc_reader",
        ":py_writer",
    ],
    deps = [
        "@rules_python//python/runfiles",
    ],
)
56 |
--------------------------------------------------------------------------------
/envlogger/backends/cross_language_test/cc_reader.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include
16 | #include
17 | #include
18 | #include
19 |
20 | #include "absl/flags/parse.h"
21 | #include "gmock/gmock.h"
22 | #include "gtest/gtest.h"
23 | #include "absl/flags/flag.h"
24 | #include
25 | #include "envlogger/backends/cc/riegeli_dataset_reader.h"
26 | #include "envlogger/converters/xtensor_codec.h"
27 | #include "envlogger/platform/proto_testutil.h"
28 | #include "envlogger/proto/storage.pb.h"
29 | #include "xtensor/xarray.hpp"
30 |
31 | ABSL_FLAG(std::string, trajectories_dir, "", "Path to reader trajectory.");
32 |
33 | using ::testing::DoubleEq;
34 | using ::testing::Eq;
35 | using ::testing::FloatEq;
36 | using ::testing::IsTrue;
37 | using ::testing::SizeIs;
38 |
39 | int main(int argc, char** argv) {
40 | absl::ParseCommandLine(argc, argv);
41 |
42 | VLOG(0) << "Starting C++ Reader...";
43 | VLOG(0) << "--trajectories_dir: " << absl::GetFlag(FLAGS_trajectories_dir);
44 | envlogger::RiegeliDatasetReader reader;
45 | const absl::Status init_status =
46 | reader.Init(absl::GetFlag(FLAGS_trajectories_dir));
47 | VLOG(0) << "init_status: " << init_status;
48 |
49 | VLOG(0) << "reader.NumSteps(): " << reader.NumSteps();
50 | for (int64_t i = 0; i < reader.NumSteps(); ++i) {
51 | std::optional step = reader.Step(i);
52 | EXPECT_THAT(step.has_value(), IsTrue())
53 | << "All steps should be readable. Step " << i << " is not available.";
54 |
55 | envlogger::DataView step_view(std::addressof(*step));
56 | EXPECT_THAT(step_view->value_case(), Eq(envlogger::Data::ValueCase::kTuple))
57 | << "Each step should be a tuple.";
58 | EXPECT_THAT(step_view, SizeIs(3))
59 | << "Each step should consist of (timestep, action, custom data)";
60 | const envlogger::Data& timestep = *step_view[0];
61 | const envlogger::Data& action = *step_view[1];
62 | const envlogger::Data& custom_data = *step_view[2];
63 | VLOG(1) << "timestep: " << timestep.ShortDebugString();
64 | VLOG(1) << "action: " << action.ShortDebugString();
65 | VLOG(1) << "custom_data: " << custom_data.ShortDebugString();
66 | envlogger::DataView timestep_view(×tep);
67 | EXPECT_THAT(timestep_view->value_case(),
68 | Eq(envlogger::Data::ValueCase::kTuple))
69 | << "Each timestep should be a tuple.";
70 | EXPECT_THAT(timestep_view, SizeIs(4))
71 | << "Each timestep should consist of (step type, reward, discount, "
72 | "observation)";
73 |
74 | // Check timestep values.
75 | // Check step type.
76 | const envlogger::Data& step_type = *timestep_view[0];
77 | VLOG(2) << "step_type: " << step_type.ShortDebugString();
78 | std::optional decoded_step_type =
79 | envlogger::Decode(step_type.datum());
80 | EXPECT_THAT(decoded_step_type.has_value(), IsTrue())
81 | << "Failed to decode step_type";
82 | EXPECT_THAT(absl::holds_alternative(*decoded_step_type),
83 | IsTrue());
84 | const mpz_class step_type_decoded =
85 | absl::get(*decoded_step_type);
86 | EXPECT_THAT(cmp(step_type_decoded, i ? 1 : 0), Eq(0));
87 | // Check reward.
88 | const envlogger::Data& reward = *timestep_view[1];
89 | VLOG(2) << "reward: " << reward.ShortDebugString();
90 | std::optional decoded_reward =
91 | envlogger::Decode(reward.datum());
92 | EXPECT_THAT(decoded_reward.has_value(), IsTrue())
93 | << "Failed to decode reward";
94 | EXPECT_THAT(absl::holds_alternative(*decoded_reward), IsTrue());
95 | const double r = absl::get(*decoded_reward);
96 | EXPECT_THAT(r, DoubleEq(i / 100.0));
97 | // Check discount.
98 | const envlogger::Data& discount = *timestep_view[2];
99 | VLOG(2) << "discount: " << discount.ShortDebugString();
100 | std::optional decoded_discount =
101 | envlogger::Decode(discount.datum());
102 | EXPECT_THAT(decoded_discount.has_value(), IsTrue())
103 | << "Failed to decode discount";
104 | EXPECT_THAT(absl::holds_alternative(*decoded_discount), IsTrue());
105 | const double gamma = absl::get(*decoded_discount);
106 | EXPECT_THAT(gamma, DoubleEq(0.99));
107 | // Check observation.
108 | const envlogger::Data& observation = *timestep_view[3];
109 | VLOG(2) << "observation: " << observation.ShortDebugString();
110 | std::optional decoded_obs =
111 | envlogger::Decode(observation.datum());
112 | EXPECT_THAT(decoded_obs.has_value(), IsTrue())
113 | << "Failed to decode observation";
114 | EXPECT_THAT(absl::holds_alternative>(*decoded_obs),
115 | IsTrue());
116 | const xt::xarray& obs = absl::get>(*decoded_obs);
117 | EXPECT_THAT(obs, SizeIs(1));
118 | EXPECT_THAT(obs(0), FloatEq(i));
119 |
120 | // Check action.
121 | std::optional decoded_action =
122 | envlogger::Decode(action.datum());
123 | EXPECT_THAT(decoded_action.has_value(), IsTrue())
124 | << "Failed to decode action";
125 | EXPECT_THAT(absl::holds_alternative(*decoded_action), IsTrue());
126 | const int a = absl::get(*decoded_action);
127 | EXPECT_THAT(a, Eq(100 - i));
128 |
129 | // There should be no custom data, but it should still be a valid pointer.
130 | EXPECT_THAT(custom_data, EqualsProto(envlogger::Data()));
131 | }
132 |
133 | reader.Close();
134 |
135 | return 0;
136 | }
137 |
--------------------------------------------------------------------------------
/envlogger/backends/cross_language_test/cross_language_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests writing trajectories in one language and reading from another."""
17 |
18 | from collections.abc import Sequence
19 | import os
20 | import shutil
21 | import subprocess
22 |
23 | from absl import logging
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 |
27 | from rules_python.python.runfiles import runfiles
28 |
29 |
def _execute_binary(rel_path: str, args: Sequence[str]) -> bytes:
  """Runs a binary from this test's runfiles and returns its stdout.

  Args:
    rel_path: Path of the binary relative to the `envlogger` directory of the
      `__main__` workspace.
    args: Command-line arguments passed to the binary.

  Returns:
    The bytes written to stdout by the binary.

  Raises:
    FileNotFoundError: If the binary cannot be located in the runfiles.
    subprocess.CalledProcessError: If the binary exits with a non-zero status.
  """
  r = runfiles.Create()
  runfile = os.path.join('__main__', 'envlogger', rel_path)
  path = r.Rlocation(runfile)
  if not path:
    # Fail with a clear message instead of passing a missing path to subprocess.
    raise FileNotFoundError(f'Could not find {runfile!r} in the runfiles.')
  # `args` may be any sequence (e.g. a tuple); list concatenation with `+`
  # would only accept a list.
  cmd = [path, *args]
  return subprocess.check_output(cmd, env=r.EnvVars())
35 |
36 |
class CrossLanguageTest(parameterized.TestCase):
  """Checks that a trajectory written from Python can be read from C++."""

  def test_py_writer_cc_reader(self):
    """Runs the Python writer, then the C++ reader, on the same directory."""
    # Set up a trajectory directory.
    trajectories_dir = os.path.join(absltest.TEST_TMPDIR.value, 'my_trajectory')
    logging.info('trajectories_dir: %r', trajectories_dir)
    # A previous run that failed may have left the directory behind; start from
    # a clean one so `os.makedirs` does not raise.
    shutil.rmtree(trajectories_dir, ignore_errors=True)
    os.makedirs(trajectories_dir)

    try:
      # Find Python writer and run it.
      py_writer_output = _execute_binary(
          'backends/cross_language_test/py_writer',
          args=[f'--trajectories_dir={trajectories_dir}'])
      logging.info('py_writer_output: %r', py_writer_output)

      # Find C++ reader and run it.
      cc_reader_output = _execute_binary(
          'backends/cross_language_test/cc_reader',
          args=[f'--trajectories_dir={trajectories_dir}'])
      logging.info('cc_reader_output: %r', cc_reader_output)

      # If everything went well, there should be no
      # `subprocess.CalledProcessError`.
    finally:
      # Clean up even when one of the binaries fails.
      logging.info('Cleaning up trajectories_dir %r', trajectories_dir)
      shutil.rmtree(trajectories_dir)
62 |
63 |
64 | if __name__ == '__main__':
65 | absltest.main()
66 |
--------------------------------------------------------------------------------
/envlogger/backends/cross_language_test/py_writer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple python binary that creates a simple RL trajectory."""
17 |
18 | from collections.abc import Sequence
19 |
20 | from absl import app
21 | from absl import flags
22 | from absl import logging
23 | import dm_env
24 | import envlogger
25 | import numpy as np
26 |
27 | _TRAJECTORIES_DIR = flags.DEFINE_string(
28 | 'trajectories_dir', None, 'Path to write trajectory.', required=True)
29 |
30 |
def main(argv: Sequence[str]) -> None:
  """Writes a single 100-step episode to `--trajectories_dir`.

  Step `i` has observation `[i]`, reward `i / 100`, discount `0.99` and action
  `100 - i`. Step 0 is a `FIRST` step that starts the episode; all others are
  `MID` steps.

  Args:
    argv: Positional command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.info('Starting Python-based writer...')
  logging.info('--trajectories_dir: %r', _TRAJECTORIES_DIR.value)

  writer = envlogger.RiegeliBackendWriter(
      data_directory=_TRAJECTORIES_DIR.value, metadata={'my_data': [1, 2, 3]})
  try:
    for i in range(100):
      is_first = i == 0
      step_type = dm_env.StepType.FIRST if is_first else dm_env.StepType.MID
      writer.record_step(
          envlogger.StepData(
              timestep=dm_env.TimeStep(
                  observation=np.array([float(i)], dtype=np.float32),
                  reward=i / 100.0,
                  discount=0.99,
                  step_type=step_type),
              action=np.int32(100 - i)),
          is_new_episode=is_first)
  finally:
    # Close the writer even if recording a step fails.
    writer.close()
60 |
61 |
62 | if __name__ == '__main__':
63 | app.run(main)
64 |
--------------------------------------------------------------------------------
/envlogger/backends/in_memory_backend.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Environment logger backend that stores all data in RAM.
17 | """
18 |
19 | import copy
20 | from typing import Any
21 |
22 | from envlogger import step_data
23 | from envlogger.backends import backend_reader
24 | from envlogger.backends import backend_writer
25 | from envlogger.backends.python import episode_info
26 |
27 |
28 |
class InMemoryBackendWriter(backend_writer.BackendWriter):
  """Backend writer that keeps the whole trajectory in process memory."""

  def __init__(self, **base_kwargs):
    """Creates an empty writer; `base_kwargs` are forwarded to the base class."""
    super().__init__(**base_kwargs)
    # All recorded steps, in recording order.
    self.steps = []
    # Episode metadata, keyed by the number of episodes started when it was set.
    self.episode_metadata = {}
    # Index into `steps` of the first step of each episode.
    self.episode_start_indices = []

  def _record_step(self, data: step_data.StepData,
                   is_new_episode: bool) -> None:
    """Appends `data`, marking its index as an episode start if requested."""
    step_index = len(self.steps)
    self.steps.append(data)
    if is_new_episode:
      self.episode_start_indices.append(step_index)

  def set_episode_metadata(self, data: Any) -> None:
    """Stores `data` for the current episode; ignored before any episode."""
    num_started_episodes = len(self.episode_start_indices)
    if num_started_episodes <= 0:
      return
    self.episode_metadata[num_started_episodes] = data

  def close(self) -> None:
    """No-op: there is nothing to release."""
51 |
52 |
class InMemoryBackendReader(backend_reader.BackendReader):
  """Reader that reads data from an InMemoryBackend."""

  def __init__(self, in_memory_backend_writer: InMemoryBackendWriter):
    """Wraps `in_memory_backend_writer`; its data is read in place, not copied."""
    # Set before calling the base initializer.
    # NOTE(review): presumably the base class queries this reader during
    # construction -- confirm against BackendReader.
    self._backend = in_memory_backend_writer
    super().__init__()

  def _copy(self) -> 'InMemoryBackendReader':
    """Returns a deep copy of this reader, including the wrapped backend."""
    return copy.deepcopy(self)

  def close(self) -> None:
    """No-op: there is nothing to release."""
    pass

  def _get_nth_step(self, i: int) -> step_data.StepData:
    """Returns the step stored at index `i`."""
    return self._backend.steps[i]

  def _get_nth_episode_info(self,
                            i: int,
                            include_metadata: bool = False
                           ) -> episode_info.EpisodeInfo:
    """Returns start index, length and metadata of the `i`-th episode.

    Args:
      i: 0-based episode index.
      include_metadata: Unused; metadata is always attached when available.

    Returns:
      An `EpisodeInfo` whose `metadata` is None if none was set for the episode.
    """
    starts = self._backend.episode_start_indices
    start = starts[i]
    if i == len(starts) - 1:  # Last episode.
      length = len(self._backend.steps) - start
    else:
      length = starts[i + 1] - start
    # `InMemoryBackendWriter.set_episode_metadata()` stores metadata under the
    # number of episodes started so far (1-based), while `i` is 0-based.
    # Looking up `i` directly would return the previous episode's metadata.
    episode_metadata = self._backend.episode_metadata.get(i + 1, None)
    return episode_info.EpisodeInfo(
        start=start, num_steps=length, metadata=episode_metadata)

  def _get_num_steps(self) -> int:
    """Returns the total number of recorded steps."""
    return len(self._backend.steps)

  def _get_num_episodes(self) -> int:
    """Returns the number of episodes started so far."""
    return len(self._backend.episode_start_indices)

  def metadata(self):
    """Returns the metadata reported by the wrapped writer."""
    return self._backend.metadata()
92 |
--------------------------------------------------------------------------------
/envlogger/backends/python/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

# Load `py_test` explicitly, as the other BUILD files in this project do,
# instead of relying on the rule being built into Bazel.
load("@rules_python//python:defs.bzl", "py_test")

package(default_visibility = ["//envlogger:__subpackages__"])

# Python bindings for the C++ RiegeliDatasetReader.
pybind_extension(
    name = "riegeli_dataset_reader",
    srcs = ["riegeli_dataset_reader.cc"],
    deps = [
        "//envlogger/backends/cc:riegeli_dataset_reader",
        "//envlogger/proto:storage_cc_proto",
        "@com_google_riegeli//riegeli/bytes:string_writer",
        "@com_google_riegeli//riegeli/endian:endian_writing",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

# Python bindings for the C++ RiegeliDatasetWriter.
pybind_extension(
    name = "riegeli_dataset_writer",
    srcs = ["riegeli_dataset_writer.cc"],
    deps = [
        "//envlogger/backends/cc:riegeli_dataset_writer",
        "//envlogger/proto:storage_cc_proto",
        "@com_google_riegeli//riegeli/bytes:string_writer",
        "@com_google_riegeli//riegeli/endian:endian_writing",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

# Python bindings for the C++ EpisodeInfo struct.
pybind_extension(
    name = "episode_info",
    srcs = ["episode_info.cc"],
    deps = [
        "//envlogger/backends/cc:episode_info",
        "//envlogger/proto:storage_cc_proto",
        "@pybind11_protobuf//pybind11_protobuf:proto_casters",
    ],
)

# The test imports the compiled extension, so it is listed under `data`.
py_test(
    name = "episode_info_test",
    srcs = ["episode_info_test.py"],
    data = [":episode_info.so"],
    deps = [
        "//envlogger/proto:storage_py_pb2",
    ],
)
61 |
--------------------------------------------------------------------------------
/envlogger/backends/python/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/backends/python/episode_info.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/backends/cc/episode_info.h"
16 |
17 | #include
18 | #include
19 |
20 | #include "envlogger/proto/storage.pb.h"
21 | #include "pybind11//pybind11.h"
22 | #include "pybind11//stl.h"
23 | #include "pybind11_protobuf/proto_casters.h"
24 |
25 | PYBIND11_MODULE(episode_info, m) {
26 | pybind11::google::ImportProtoModule();
27 |
28 | m.doc() = "EpisodeInfo bindings.";
29 |
30 | pybind11::class_(m, "EpisodeInfo")
31 | .def(pybind11::init>(),
32 | pybind11::arg("start") = 0, pybind11::arg("num_steps") = 0,
33 | pybind11::arg("metadata") = std::nullopt)
34 | .def_readwrite("start", &envlogger::EpisodeInfo::start)
35 | .def_readwrite("num_steps", &envlogger::EpisodeInfo::num_steps)
36 | .def_readwrite("metadata", &envlogger::EpisodeInfo::metadata);
37 | }
38 |
--------------------------------------------------------------------------------
/envlogger/backends/python/episode_info_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for episode_info.cc."""
17 |
18 | import random
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from envlogger.backends.python import episode_info
23 | from envlogger.proto import storage_pb2
24 |
25 |
class EpisodeInfoTest(parameterized.TestCase):
  """Tests for the `EpisodeInfo` Python bindings."""

  def test_empty_episode_info(self):
    """A default-constructed EpisodeInfo is zeroed and has no metadata."""
    info = episode_info.EpisodeInfo()
    self.assertEqual(info.start, 0)
    self.assertEqual(info.num_steps, 0)
    self.assertIsNone(info.metadata)

  def test_episode_info_init_with_random_kwargs(self):
    """Keyword arguments passed to the constructor are readable afterwards."""
    num_cases = 100
    random_starts = [random.randint(-1, 10000) for _ in range(num_cases)]
    random_num_steps = [random.randint(-1, 10000) for _ in range(num_cases)]

    dimension = storage_pb2.Datum.Shape.Dim()
    dimension.size = -438

    def make_metadata() -> storage_pb2.Data:
      # One-dimension datum holding a single random int32 value.
      proto = storage_pb2.Data()
      proto.datum.shape.dim.append(dimension)
      proto.datum.values.int32_values.append(random.randint(-1, 10000))
      return proto

    random_metadata = [make_metadata() for _ in range(num_cases)]

    for case in range(num_cases):
      start = random_starts[case]
      num_steps = random_num_steps[case]
      metadata = random_metadata[case]
      info = episode_info.EpisodeInfo(
          start=start, num_steps=num_steps, metadata=metadata)
      self.assertEqual(info.start, start)
      self.assertEqual(info.num_steps, num_steps)
      self.assertSequenceEqual(info.metadata.datum.values.int32_values,
                               metadata.datum.values.int32_values)
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/envlogger/backends/python/riegeli_dataset_writer.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/backends/cc/riegeli_dataset_writer.h"
16 |
17 | #include
18 | #include
19 | #include
20 |
21 | #include "envlogger/proto/storage.pb.h"
22 | #include "pybind11//pybind11.h"
23 | #include "pybind11//stl.h"
24 | #include "pybind11_protobuf/proto_casters.h"
25 | #include "riegeli/bytes/string_writer.h"
26 | #include "riegeli/endian/endian_writing.h"
27 |
28 | namespace {
29 |
30 | // This traverses the proto to convert repeated 'float_values' into big-endian
31 | // byte arrays 'float_values_buffer'. In python this allows using
32 | // np.frombuffer(float_values_buffer) which is overall 2.3x more efficient than
33 | // building the np array from the repeated field (this includes the cost of
34 | // doing the conversion in C++).
35 | // In the future, we could do the same for other data types stored in repeated
36 | // fields.
37 | void OptimizeDataProto(envlogger::Data* data) {
38 | switch (data->value_case()) {
39 | case envlogger::Data::kDatum: {
40 | auto* datum = data->mutable_datum();
41 | if (!datum->values().float_values().empty()) {
42 | riegeli::StringWriter writer(
43 | datum->mutable_values()->mutable_float_values_buffer());
44 | writer.SetWriteSizeHint(datum->values().float_values_size() *
45 | sizeof(float));
46 | riegeli::WriteBigEndianFloats(datum->values().float_values(), writer);
47 | writer.Close();
48 | datum->mutable_values()->clear_float_values();
49 | }
50 | break;
51 | }
52 | case envlogger::Data::kArray:
53 | for (auto& value : *data->mutable_array()->mutable_values()) {
54 | OptimizeDataProto(&value);
55 | }
56 | break;
57 | case envlogger::Data::kTuple:
58 | for (auto& value : *data->mutable_tuple()->mutable_values()) {
59 | OptimizeDataProto(&value);
60 | }
61 | break;
62 | case envlogger::Data::kDict:
63 | for (auto& value : *data->mutable_dict()->mutable_values()) {
64 | OptimizeDataProto(&value.second);
65 | }
66 | break;
67 | case envlogger::Data::VALUE_NOT_SET:
68 | break;
69 | }
70 | }
71 | } // namespace
72 |
73 | PYBIND11_MODULE(riegeli_dataset_writer, m) {
74 | pybind11::google::ImportProtoModule();
75 | pybind11::module::import("envlogger.backends.python.episode_info");
76 |
77 | m.doc() = "RiegeliDatasetWriter bindings.";
78 |
79 | pybind11::class_(m, "RiegeliDatasetWriter")
80 | .def(pybind11::init<>())
81 | // Initializes the writer with the given arguments.
82 | // If successful, `void` is returned with no side effects. Otherwise a
83 | // `RuntimeError` is raised with an accompanying message.
84 | // Note: `absl::Status` isn't used because there are incompatibilities
85 | // between slightly different versions of `pybind11_abseil` when used with
86 | // different projects. Please see
87 | // https://github.com/deepmind/envlogger/issues/3 for more details.
88 | .def(
89 | "init",
90 | [](envlogger::RiegeliDatasetWriter* self, std::string data_dir,
91 | const envlogger::Data& metadata = envlogger::Data(),
92 | int64_t max_episodes_per_shard = 0,
93 | std::string writer_options =
94 | "transpose,brotli:6,chunk_size:1M") -> void {
95 | const absl::Status status = self->Init(
96 | data_dir, metadata, max_episodes_per_shard, writer_options);
97 | if (!status.ok()) throw std::runtime_error(status.ToString());
98 | },
99 | pybind11::arg("data_dir"),
100 | pybind11::arg("metadata") = envlogger::Data(),
101 | pybind11::arg("max_episodes_per_shard") = 0,
102 | pybind11::arg("writer_options") = "transpose,brotli:6,chunk_size:1M")
103 | .def("add_step", &envlogger::RiegeliDatasetWriter::AddStep,
104 | pybind11::arg("data"), pybind11::arg("is_new_episode") = false)
105 | .def("set_episode_metadata",
106 | &envlogger::RiegeliDatasetWriter::SetEpisodeMetadata,
107 | pybind11::arg("data"))
108 | .def("flush", &envlogger::RiegeliDatasetWriter::Flush)
109 | .def("close", &envlogger::RiegeliDatasetWriter::Close)
110 | // Accessors.
111 | .def("data_dir", &envlogger::RiegeliDatasetWriter::DataDir)
112 | .def("max_episodes_per_shard",
113 | &envlogger::RiegeliDatasetWriter::MaxEpisodesPerShard)
114 | .def("writer_options", &envlogger::RiegeliDatasetWriter::WriterOptions)
115 | .def("episode_counter", &envlogger::RiegeliDatasetWriter::EpisodeCounter)
116 | // Pickling support.
117 | .def(pybind11::pickle(
118 | [](const envlogger::RiegeliDatasetWriter& self) { // __getstate__().
119 | pybind11::dict output;
120 | output["data_dir"] = self.DataDir();
121 | output["max_episodes_per_shard"] = self.MaxEpisodesPerShard();
122 | output["writer_options"] = self.WriterOptions();
123 | output["episode_counter_"] = self.EpisodeCounter();
124 | return output;
125 | },
126 | [](pybind11::dict d) { // __setstate__().
127 | const std::string data_dir = d["data_dir"].cast();
128 | const int64_t max_episodes_per_shard =
129 | d["max_episodes_per_shard"].cast();
130 | const std::string writer_options =
131 | d["writer_options"].cast();
132 |
133 | auto writer = std::make_unique();
134 | const absl::Status status = writer->Init(
135 | /*data_dir=*/data_dir, /*metadata=*/envlogger::Data(),
136 | /*max_episodes_per_shard=*/max_episodes_per_shard,
137 | /*writer_options=*/writer_options);
138 | if (!status.ok()) {
139 | throw std::runtime_error(
140 | "Failed to initialize RiegeliDatasetWriter: " +
141 | status.ToString());
142 | }
143 | return writer;
144 | }));
145 | }
146 |
--------------------------------------------------------------------------------
/envlogger/backends/riegeli_backend_reader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """For reading trajectory data from riegeli files."""
17 |
18 | import copy
19 | from typing import Any
20 |
21 | from absl import logging
22 | import dm_env
23 | from envlogger import step_data
24 | from envlogger.backends import backend_reader
25 | from envlogger.backends.python import episode_info
26 | from envlogger.backends.python import riegeli_dataset_reader
27 | from envlogger.converters import codec
28 | from envlogger.proto import storage_pb2
29 |
30 |
31 |
32 |
33 |
class RiegeliBackendReader(backend_reader.BackendReader):
  """A class that reads logs produced by an EnvironmentLoggerWrapper instance.

  Attributes:
    episodes: Traverse the data episode-wise in list-like fashion.
    steps: Traverse the data stepwise in list-like fashion.
  """

  def __init__(self, data_directory: str):
    """Opens the riegeli dataset stored under `data_directory`.

    A `RuntimeError` whose message starts with 'NOT_FOUND: Empty steps in ' is
    logged and ignored so that the remaining shards can still be read; any
    other `RuntimeError` raised by the underlying reader is re-raised.

    Args:
      data_directory: Forwarded unchanged to `RiegeliDatasetReader.init()`.
    """
    self._reader = riegeli_dataset_reader.RiegeliDatasetReader()
    try:
      self._reader.init(data_directory)
    except RuntimeError as e:
      error_message = str(e)
      if error_message.startswith('NOT_FOUND: Empty steps in '):
        # This case happens frequently when clients abruptly kill the
        # EnvironmentLogger without calling its .close() method, which then
        # causes the last shard to be truncated. This can be because the client
        # exited successfully and "forgot" to call .close(), which is a bug, but
        # also because of a preempted work unit, which is expected to happen
        # under distributed settings.
        # We can't do much to fix the bad usages, but we can be a bit more
        # permissive and try to read the successful shards.
        logging.exception("""Ignoring error due to empty step offset file.
*********************************
**** You likely forgot to ***
**** call close() on your env ***
**** ***
*********************************""")
      else:
        raise

    # Fall back to an empty dict when the decoded metadata is falsy.
    self._metadata = codec.decode(self._reader.metadata()) or {}
    super().__init__()

  def _copy(self) -> 'RiegeliBackendReader':
    """Returns a copy that owns its own metadata and its own reader clone."""
    c = copy.copy(self)

    # The shallow copy shares attributes with `self`; replace the two that must
    # not be shared.
    c._metadata = copy.deepcopy(self._metadata)
    c._reader = self._reader.clone()

    return c

  def close(self) -> None:
    """Closes the underlying reader; calling it again is a no-op."""
    if self._reader is not None:
      self._reader.close()
    # Dropping the reference is what makes a second close() a no-op.
    self._reader = None

  def _decode_step_data(self, data: tuple[Any, Any, Any]) -> step_data.StepData:
    """Rebuilds a `StepData` from decoded logged data.

    Args:
      data: Indexable with three entries. `data[0]` holds the four positional
        `dm_env.TimeStep` fields (the first one is converted to
        `dm_env.StepType`); `data[1]` and `data[2]` are passed through as the
        second and third `StepData` arguments.

    Returns:
      The reconstructed `StepData`.
    """
    # Recover the TimeStep from the first tuple element.
    timestep = dm_env.TimeStep(
        dm_env.StepType(data[0][0]), data[0][1], data[0][2], data[0][3])
    return step_data.StepData(timestep, data[1], data[2])

  def _get_num_steps(self):
    """Returns the `num_steps` attribute of the underlying reader."""
    return self._reader.num_steps

  def _get_num_episodes(self):
    """Returns the `num_episodes` attribute of the underlying reader."""
    return self._reader.num_episodes

  def _get_nth_step(self, i: int) -> step_data.StepData:
    """Returns the timestep given by offset `i` (0-based)."""
    serialized_data = self._reader.serialized_step(i)
    # Parse the serialized bytes into a storage proto before decoding it.
    data = storage_pb2.Data.FromString(serialized_data)
    return self._decode_step_data(codec.decode(data))

  def _get_nth_episode_info(self,
                            i: int,
                            include_metadata: bool = False
                           ) -> episode_info.EpisodeInfo:
    """Returns the index of the start of nth episode, and its length.

    Args:
      i: Episode index, forwarded to the underlying reader.
      include_metadata: Forwarded to the underlying reader's `episode()` call.
    """
    return self._reader.episode(i, include_metadata)

  def metadata(self):
    """Returns the dataset metadata decoded at construction time."""
    return self._metadata
110 |
--------------------------------------------------------------------------------
/envlogger/backends/riegeli_backend_writer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """For writing trajectory data to riegeli files."""
17 |
18 | from typing import Any, Optional
19 |
20 | from absl import logging
21 | from envlogger import step_data
22 | from envlogger.backends import backend_writer
23 | from envlogger.backends import schedulers
24 | from envlogger.backends.python import riegeli_dataset_writer
25 | from envlogger.converters import codec
26 |
27 |
28 |
class RiegeliBackendWriter(backend_writer.BackendWriter):
  """Backend writer that persists trajectory data as riegeli files."""

  def __init__(
      self,
      data_directory: str,
      max_episodes_per_file: int = 10000,
      writer_options: str = 'transpose,brotli:6,chunk_size:1M',
      flush_scheduler: Optional[schedulers.Scheduler] = None,
      **base_kwargs,
  ):
    """Sets up the underlying riegeli dataset writer.

    `close()` flushes both the trajectories and the index to disk, which is
    what makes them readable afterwards. Skipping it risks data loss: for a
    long-running environment typically only the last shard is affected, but
    for a short-lived instance a large share of the trajectories can vanish.
    Some RL frameworks never clean up their environments, so this is common.

    A `RuntimeError` raised while initializing the underlying writer is logged
    and not re-raised.

    Args:
      data_directory: Destination for the episode data. IMPORTANT:
        `data_directory` MUST exist _before_ calling `__init__()`.
      max_episodes_per_file: Upper bound on the episodes stored in one file.
      writer_options: Comma-separated options handed verbatim to the Riegeli
        RecordWriter.
      flush_scheduler: Decides when data is flushed to permanent storage. When
        `None`, a step-wise Bernoulli scheduler with probability 1/5000 is
        used.
      **base_kwargs: Arguments for the base class.
    """
    super().__init__(**base_kwargs)
    self._data_directory = data_directory
    # Only build the default scheduler when the caller did not supply one.
    self._flush_scheduler = (
        schedulers.BernoulliStepScheduler(1.0 / 5000)
        if flush_scheduler is None else flush_scheduler)
    self._data_writer = riegeli_dataset_writer.RiegeliDatasetWriter()
    logging.info('self._data_directory: %r', self._data_directory)

    dataset_metadata = self._metadata or {}

    try:
      encoded_metadata = codec.encode(dataset_metadata)
      self._data_writer.init(
          data_dir=data_directory,
          metadata=encoded_metadata,
          max_episodes_per_shard=max_episodes_per_file,
          writer_options=writer_options)
    except RuntimeError as init_error:
      logging.exception('exception: %r', init_error)

  def _record_step(self, data: step_data.StepData,
                   is_new_episode: bool) -> None:
    """Encodes and writes one step, then flushes if the scheduler says so.

    Args:
      data: The step to store.
      is_new_episode: Forwarded to the underlying writer's `add_step()`.

    Raises:
      RuntimeError: If the underlying writer reports a failed write.
    """
    written = self._data_writer.add_step(codec.encode(data), is_new_episode)
    if not written:
      raise RuntimeError(
          'Failed to write `data`. Please see logs for more details.')

    # Flush when there is no scheduler at all, or when the scheduler fires.
    scheduler = self._flush_scheduler
    if scheduler is None or scheduler(data):
      self._data_writer.flush()

  def set_episode_metadata(self, data: Any) -> None:
    """Encodes `data` and hands it to the underlying writer."""
    self._data_writer.set_episode_metadata(codec.encode(data))

  def close(self) -> None:
    """Closes the underlying writer, logging before and after."""
    directory = self._data_directory
    logging.info('Deleting the backend with data_dir: %r', directory)
    self._data_writer.close()
    logging.info('Done deleting the backend with data_dir: %r', directory)
100 |
--------------------------------------------------------------------------------
/envlogger/backends/rlds_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utils to convert Envlogger data into RLDS."""
17 |
18 | from typing import Any, Optional
19 |
20 | from absl import logging
21 | from envlogger import step_data
22 | import numpy as np
23 | import tensorflow as tf
24 | import tensorflow_datasets as tfds
25 |
26 | Step = dict[str, Any]
27 |
28 |
def to_rlds_step(prev_step: step_data.StepData,
                 step: Optional[step_data.StepData]) -> Step:
  """Combines two consecutive Envlogger steps into a single RLDS step.

  The result follows the RLDS convention described at
  https://github.com/google-research/rlds.

  Args:
    prev_step: previous step; supplies the observation and the is_* flags.
    step: current step; supplies action, reward and discount. If None, the last
      step is built: action, reward and discount are zero-filled structures
      shaped like the corresponding fields of `prev_step`.

  Returns:
    RLDS Step. When `prev_step.custom_data` is a dict, its entries are merged
    in after the standard keys.

  """
  prev_timestep = prev_step.timestep

  if step:
    action = step.action
    discount = step.timestep.discount
    reward = step.timestep.reward
  else:
    # No following step: these fields are undefined, so emit zeros.
    zeros_like = lambda nest: tf.nest.map_structure(np.zeros_like, nest)
    action = zeros_like(prev_step.action)
    discount = zeros_like(prev_timestep.discount)
    reward = zeros_like(prev_timestep.reward)

  rlds_step = {
      'action': action,
      'discount': discount,
      'is_first': prev_timestep.first(),
      'is_last': prev_timestep.last(),
      'is_terminal': (prev_timestep.last() and
                      prev_timestep.discount == 0.0),
      'observation': prev_timestep.observation,
      'reward': reward,
  }
  if isinstance(prev_step.custom_data, dict):
    rlds_step.update(prev_step.custom_data)
  return rlds_step
67 |
68 |
def _find_extra_shard(split_info: tfds.core.SplitInfo) -> Optional[Any]:
  """Looks for one shard file beyond those recorded in `split_info`.

  Args:
    split_info: split whose recorded shard count is used to build the path of
      the would-be next shard.

  Returns:
    The path of the extra shard if that file exists; None if it does not exist
    or if the split has no filename template.
  """
  template = split_info.filename_template
  if template is None:
    # Filename template is not initialized.
    return None

  known_shards = split_info.num_shards
  candidate = template.sharded_filepath(
      shard_index=known_shards, num_shards=known_shards + 1)
  # An existing file here is a shard for which there is no metadata.
  return candidate if tf.io.gfile.exists(candidate) else None
80 |
81 |
def maybe_recover_last_shard(builder: tfds.core.DatasetBuilder):
  """Registers shards that exist on disk but are missing from the metadata.

  For every split, this checks whether one shard beyond the recorded ones
  exists. If so, the shard is read in full (which may take some time) to count
  its examples and bytes, and the split information is rewritten to include it.

  Only the last shard is assumed to be possibly unaccounted for, because the
  logger generates shards sequentially and updates the metadata once a shard is
  done and before starting the next one.

  Args:
    builder: TFDS builder of the dataset that may have incomplete shards.

  Returns:
    The same builder, with updated split information. The dataset info is
    written back to `builder.data_dir` only if at least one split changed.

  """
  recovered_any = False
  for _, split_info in builder.info.splits.items():
    shard_path = _find_extra_shard(split_info)
    if shard_path is None:
      continue

    logging.info('Recovering data for shard %s.', shard_path)
    recovered_any = True

    # One entry per record: the serialized size of that record.
    record_sizes = [
        len(record.numpy()) for record in tf.data.TFRecordDataset(shard_path)
    ]
    patched_split = split_info.replace(
        shard_lengths=split_info.shard_lengths + [len(record_sizes)],
        num_bytes=split_info.num_bytes + sum(record_sizes))

    # Rebuild the split dict with the patched split replacing its old entry.
    other_splits = [
        info for name, info in builder.info.splits.items()
        if name != patched_split.name
    ]
    builder.info.set_splits(
        tfds.core.SplitDict(other_splits + [patched_split]))

  if recovered_any:
    builder.info.write_to_directory(builder.data_dir)
  return builder
124 |
--------------------------------------------------------------------------------
/envlogger/backends/schedulers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common logging scheduling strategies."""
17 |
18 | from collections.abc import Callable
19 | from typing import Optional, Union
20 |
21 | from envlogger import step_data
22 | import numpy as np
23 |
24 |
25 | # A Scheduler returns True when something should be activated and False
26 | # otherwise.
27 | Scheduler = Callable[[step_data.StepData], bool]
28 |
29 |
class NStepScheduler:
  """Fires on the first call and then on every `step_interval`-th call."""

  def __init__(self, step_interval: int):
    """Stores the interval and resets the call counter.

    Args:
      step_interval: Number of calls between two consecutive `True` results.

    Raises:
      ValueError: If `step_interval` is zero or negative.
    """
    if step_interval <= 0:
      raise ValueError(f'step_interval must be positive, got {step_interval}')

    self._step_counter = 0
    self._step_interval = step_interval

  def __call__(self, unused_data: step_data.StepData):
    """Returns `True` when the number of previous calls is a multiple of N."""

    calls_so_far = self._step_counter
    self._step_counter = calls_so_far + 1
    return calls_so_far % self._step_interval == 0
46 |
47 |
class BernoulliStepScheduler:
  """Fires independently on each call with a fixed probability."""

  def __init__(self, keep_probability: float, seed: Optional[int] = None):
    """Validates the probability and creates the random generator.

    Args:
      keep_probability: Probability of returning `True` on any given call.
      seed: Seed passed to `np.random.default_rng`.

    Raises:
      ValueError: If `keep_probability` is below 0 or above 1.
    """
    out_of_range = keep_probability < 0.0 or keep_probability > 1.0
    if out_of_range:
      raise ValueError(
          f'keep_probability must be in [0,1], got: {keep_probability}')

    self._rng = np.random.default_rng(seed)
    self._keep_probability = keep_probability

  def __call__(self, unused_data: step_data.StepData):
    """Draws a fresh sample and returns whether it is below the threshold."""

    sample = self._rng.random()
    return sample < self._keep_probability
63 |
64 |
class NEpisodeScheduler:
  """Fires for every step of the first episode and of each N-th one after."""

  def __init__(self, episode_interval: int):
    """Stores the interval and resets the episode counter.

    Args:
      episode_interval: Number of episodes between two selected episodes.

    Raises:
      ValueError: If `episode_interval` is zero or negative.
    """
    if episode_interval <= 0:
      raise ValueError(
          f'episode_interval must be positive, got {episode_interval}')

    self._episode_counter = 0
    self._episode_interval = episode_interval

  def __call__(self, data: step_data.StepData):
    """Returns whether the episode that `data` belongs to is selected."""

    # Decide first, so the last step is still attributed to its own episode.
    selected = (self._episode_counter % self._episode_interval) == 0
    episode_finished = data.timestep.last()
    if episode_finished:
      self._episode_counter += 1
    return selected
83 |
84 |
class BernoulliEpisodeScheduler:
  """Returns `True` with a given probability at every episode.

  One random sample is drawn per episode and every step of that episode,
  including its last step, receives the same answer.
  """

  def __init__(self, keep_probability: float, seed: Optional[int] = None):
    """Validates the probability and draws the sample for the first episode.

    Args:
      keep_probability: Probability that an episode is selected.
      seed: Seed passed to `np.random.default_rng`.

    Raises:
      ValueError: If `keep_probability` is below 0 or above 1.
    """
    if keep_probability < 0.0 or keep_probability > 1.0:
      raise ValueError(
          f'keep_probability must be in [0,1], got: {keep_probability}')

    self._keep_probability = keep_probability
    self._rng = np.random.default_rng(seed)
    self._current_p = self._rng.random()

  def __call__(self, data: step_data.StepData):
    """Returns `True` with probability `self._keep_probability` per episode."""

    # Decide with the current episode's sample BEFORE resampling. Resampling
    # first would give the last step of an episode the next episode's decision,
    # which is inconsistent with the counter-based episode schedulers.
    should_log = self._current_p < self._keep_probability
    if data.timestep.last():
      self._current_p = self._rng.random()
    return should_log
103 |
104 |
class ListStepScheduler:
  """Returns `True` for steps in `desired_steps`.

  Please see unit tests for examples of using this scheduler. In particular,
  you can use Numpy's functions such as logspace() to generate non-linear steps.
  """

  def __init__(self, desired_steps: Union[list[int], np.ndarray]):
    """Stores the set of step indices for which the scheduler fires.

    Args:
      desired_steps: 0-based call indices that should yield `True`.

    Raises:
      TypeError: If `desired_steps` is a Numpy array whose dtype is neither
        np.int32 nor np.int64.
      ValueError: If `desired_steps` is empty.
    """
    if (isinstance(desired_steps, np.ndarray) and
        not (desired_steps.dtype == np.int32 or
             desired_steps.dtype == np.int64)):
      raise TypeError(
          f'desired_steps.dtype must be np.int32 or np.int64: {desired_steps} '
          f'(dtype: {desired_steps.dtype})')
    if len(desired_steps) <= 0:
      raise ValueError(f'desired_steps cannot be empty: {desired_steps}')

    self._desired_steps = set(desired_steps)
    self._step_counter = 0

  def __call__(self, data: step_data.StepData):
    """Returns `True` if the index of this call is in `desired_steps`.

    Args:
      data: Unused; the decision depends only on how many calls came before.
    """

    should_log = self._step_counter in self._desired_steps
    self._step_counter += 1
    return should_log
131 |
132 |
class ListEpisodeScheduler:
  """Returns `True` for episodes in `desired_episodes`.

  Please see unit tests for examples of using this scheduler. In particular,
  you can use Numpy's functions such as logspace() to generate non-linear steps.
  """

  def __init__(self, desired_episodes: Union[list[int], np.ndarray]):
    """Stores the set of episode indices for which the scheduler fires.

    Args:
      desired_episodes: 0-based episode indices that should yield `True`.

    Raises:
      TypeError: If `desired_episodes` is a Numpy array whose dtype is neither
        np.int32 nor np.int64.
      ValueError: If `desired_episodes` is empty.
    """
    if (isinstance(desired_episodes, np.ndarray) and
        not (desired_episodes.dtype == np.int32 or
             desired_episodes.dtype == np.int64)):
      raise TypeError('desired_episodes.dtype must be np.int32 or np.int64: '
                      f'{desired_episodes} (dtype: {desired_episodes.dtype})')
    if len(desired_episodes) <= 0:
      raise ValueError(f'desired_episodes cannot be empty: {desired_episodes}')

    self._desired_episodes = set(desired_episodes)
    self._episode_counter = 0

  def __call__(self, data: step_data.StepData):
    """Returns `True` if the current episode index is in `desired_episodes`.

    The episode index advances after a step whose `timestep.last()` is true, so
    that last step is still attributed to the episode it ends.

    Args:
      data: Step whose `timestep.last()` marks the end of an episode.
    """

    should_log = self._episode_counter in self._desired_episodes
    if data.timestep.last():
      self._episode_counter += 1
    return should_log
159 |
--------------------------------------------------------------------------------
/envlogger/backends/tfds_backend_testlib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utils to test the backends."""
17 | import time
18 | from typing import Any, Optional
19 |
20 | from absl import logging
21 | from envlogger import step_data
22 | from envlogger.backends import backend_writer
23 | from envlogger.backends import tfds_backend_writer
24 | from envlogger.testing import catch_env
25 | import numpy as np
26 | import tensorflow as tf
27 | import tensorflow_datasets as tfds
28 |
29 |
def generate_episode_data(
    backend: backend_writer.BackendWriter,
    num_episodes: int = 2,
) -> list[list[step_data.StepData]]:
  """Plays `num_episodes` of Catch with random actions, logging every step.

  Each step is recorded through `backend` with a `timestamp` entry as custom
  data, and each episode gets `{'episode_id': <index>}` as episode metadata.
  Both the environment and `backend` are closed before returning.

  Args:
    backend: environment logger writer.
    num_episodes: number of episodes to generate.

  Returns:
    List of generated episodes, each one being the list of its steps.

  """

  def _now_metadata():
    # Custom data attached to every step.
    return {'timestamp': int(time.time())}

  env = catch_env.Catch()

  logging.info('Training a random agent for %r episodes...', num_episodes)
  all_episodes = []
  for episode_id in range(num_episodes):
    timestep = env.reset()
    first_step = step_data.StepData(timestep, None, _now_metadata())
    steps = [first_step]
    backend.record_step(first_step, is_new_episode=True)

    while not timestep.last():
      action = np.random.randint(low=0, high=3)
      timestep = env.step(action)
      next_step = step_data.StepData(timestep, action, _now_metadata())
      steps.append(next_step)
      backend.record_step(next_step, is_new_episode=False)

    backend.set_episode_metadata({'episode_id': episode_id})
    all_episodes.append(steps)

  logging.info('Done training a random agent for %r episodes.', num_episodes)
  env.close()
  backend.close()
  return all_episodes
69 |
70 |
def catch_env_tfds_config(
    name: str = 'catch_example') -> tfds.rlds.rlds_base.DatasetConfig:
  """Builds the TFDS DatasetConfig used for Catch environment data.

  Args:
    name: name given to the dataset config.

  Returns:
    A DatasetConfig with a ZLIB-encoded (10, 5) float32 observation, an int64
    action, float64 reward and discount, an int64 `timestamp` as step metadata
    and an int64 `episode_id` as episode metadata.
  """
  observation_feature = tfds.features.Tensor(
      shape=(10, 5),
      dtype=tf.float32,
      encoding=tfds.features.Encoding.ZLIB)
  step_metadata = {'timestamp': tf.int64}
  episode_metadata = {'episode_id': tf.int64}
  return tfds.rlds.rlds_base.DatasetConfig(
      name=name,
      observation_info=observation_feature,
      action_info=tf.int64,
      reward_info=tf.float64,
      discount_info=tf.float64,
      step_metadata_info=step_metadata,
      episode_metadata_info=episode_metadata)
84 |
85 |
def tfds_backend_catch_env(
    data_directory: str,
    max_episodes_per_file: int = 1,
    split_name: Optional[str] = None,
    ds_metadata: Optional[dict[Any, Any]] = None,
    store_ds_metadata: bool = True,
) -> tfds_backend_writer.TFDSBackendWriter:
  """Builds a TFDSBackendWriter configured for the Catch environment.

  Args:
    data_directory: directory where the data will be created (it has to exist).
    max_episodes_per_file: maximum number of episodes per file.
    split_name: name of the TFDS split to create.
    ds_metadata: metadata of the dataset, passed to the writer as `metadata`.
    store_ds_metadata: if the metadata should be stored.

  Returns:
    TFDS backend writer.
  """
  writer_kwargs = dict(
      data_directory=data_directory,
      split_name=split_name,
      ds_config=catch_env_tfds_config(),
      max_episodes_per_file=max_episodes_per_file,
      metadata=ds_metadata,
      store_ds_metadata=store_ds_metadata,
  )
  return tfds_backend_writer.TFDSBackendWriter(**writer_kwargs)
111 |
--------------------------------------------------------------------------------
/envlogger/backends/tfds_backend_writer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TFDS backend for Envlogger."""
17 | import dataclasses
18 | from typing import Any, Optional
19 |
20 | from absl import logging
21 | from envlogger import step_data
22 | from envlogger.backends import backend_writer
23 | from envlogger.backends import rlds_utils
24 | import tensorflow_datasets as tfds
25 |
26 |
27 | DatasetConfig = tfds.rlds.rlds_base.DatasetConfig
28 |
29 |
@dataclasses.dataclass
class Episode:
  """Accumulates the RLDS steps of an episode while it is being logged."""
  # Most recent Envlogger step; it has not been turned into an RLDS step yet.
  prev_step: step_data.StepData
  # RLDS steps built so far (None until the first one is added).
  steps: Optional[list[rlds_utils.Step]] = None
  # Episode-level metadata merged into the final episode dict.
  metadata: Optional[dict[str, Any]] = None

  def add_step(self, step: step_data.StepData) -> None:
    """Pairs `step` with the pending step and stores the resulting RLDS step."""
    converted = rlds_utils.to_rlds_step(self.prev_step, step)
    if self.steps is None:
      self.steps = []
    self.steps.append(converted)
    # `step` becomes the pending step for the next call.
    self.prev_step = step

  def get_rlds_episode(self) -> dict[str, Any]:
    """Returns the episode dict: all steps plus the final one, and metadata."""
    final_step = rlds_utils.to_rlds_step(self.prev_step, None)
    if self.steps is None:
      self.steps = []
    if self.metadata is None:
      self.metadata = {}

    episode = {'steps': [*self.steps, final_step]}
    episode.update(self.metadata)
    return episode
52 |
53 |
class TFDSBackendWriter(backend_writer.BackendWriter):
  """Backend that writes trajectory data in TFDS format (and RLDS structure)."""


  def __init__(self,
               data_directory: str,
               ds_config: tfds.rlds.rlds_base.DatasetConfig,
               max_episodes_per_file: int = 1000,
               split_name: Optional[str] = None,
               version: str = '0.0.1',
               store_ds_metadata: bool = False,
               **base_kwargs):
    """Constructor.

    Args:
      data_directory: Directory to store the data
      ds_config: Dataset Configuration.
      max_episodes_per_file: Number of episodes to store per shard.
      split_name: Name to be used by the split. If None, 'train' will be used.
      version: version (major.minor.patch) of the dataset.
      store_ds_metadata: if False, it won't store the dataset level
        metadata.
      **base_kwargs: arguments for the base class.
    """
    super().__init__(**base_kwargs)
    if not split_name:
      split_name = 'train'
    ds_identity = tfds.core.dataset_info.DatasetIdentity(
        name=ds_config.name,
        version=tfds.core.Version(version),
        data_dir=data_directory,
        module_name='')
    if store_ds_metadata:
      metadata = self._metadata
    else:
      metadata = None
    self._data_directory = data_directory
    self._ds_info = tfds.rlds.rlds_base.build_info(ds_config, ds_identity,
                                                   metadata)
    self._ds_info.set_file_format('tfrecord')

    # Episode under construction; None until the first step is recorded and
    # again after the episode has been written out.
    self._current_episode = None

    self._sequential_writer = tfds.core.SequentialWriter(
        self._ds_info, max_episodes_per_file)
    self._split_name = split_name
    self._sequential_writer.initialize_splits([split_name])
    logging.info('self._data_directory: %r', self._data_directory)

  def _write_and_reset_episode(self):
    """Writes the episode under construction, if any, and clears it."""
    if self._current_episode is not None:
      self._sequential_writer.add_examples(
          {self._split_name: [self._current_episode.get_rlds_episode()]})
      self._current_episode = None

  def _record_step(self, data: step_data.StepData,
                   is_new_episode: bool) -> None:
    """Stores RLDS steps in TFDS format."""

    if is_new_episode:
      # The previous episode is complete: flush it before starting a new one.
      self._write_and_reset_episode()

    if self._current_episode is None:
      self._current_episode = Episode(prev_step=data)
    else:
      self._current_episode.add_step(data)

  def set_episode_metadata(self, data: dict[str, Any]) -> None:
    """Attaches `data` as metadata of the episode under construction.

    If no episode is under construction (no step has been recorded yet, or the
    writer has been closed), there is nothing to attach the metadata to; the
    call logs a warning and returns instead of failing on a `None` episode.

    Args:
      data: Episode-level metadata.
    """
    if self._current_episode is None:
      logging.warning(
          'Ignoring episode metadata: no episode is currently being recorded.')
      return
    self._current_episode.metadata = data

  def close(self) -> None:
    """Writes the pending episode and closes the sequential writer."""
    logging.info('Deleting the backend with data_dir: %r', self._data_directory)
    self._write_and_reset_episode()
    self._sequential_writer.close_all()
    logging.info('Done deleting the backend with data_dir: %r',
                 self._data_directory)
130 |
--------------------------------------------------------------------------------
/envlogger/converters/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Converters to and from environment logger proto format.
16 | load("@rules_python//python:defs.bzl", "py_library", "py_test")
17 |
18 | package(default_visibility = ["//visibility:public"])
19 |
# Python codec library (codec.py); depends on the storage proto bindings.
py_library(
    name = "codec",
    srcs = ["codec.py"],
    deps = [
        "//envlogger/proto:storage_py_pb2",
    ],
)

# Unit test for the Python codec library.
py_test(
    name = "codec_test",
    srcs = ["codec_test.py"],
    deps = [
        ":codec",
        "//envlogger/proto:storage_py_pb2",
    ],
)

# Python spec codec library (spec_codec.py); declares no deps here.
py_library(
    name = "spec_codec",
    srcs = ["spec_codec.py"],
)

# Unit test for the Python spec codec library.
py_test(
    name = "spec_codec_test",
    srcs = ["spec_codec_test.py"],
    deps = [
        ":spec_codec",
        "//envlogger/proto:storage_py_pb2",
    ],
)

# C++ codec library built on xtensor, GMP, Abseil and Riegeli.
cc_library(
    name = "xtensor_codec",
    srcs = ["xtensor_codec.cc"],
    hdrs = ["xtensor_codec.h"],
    deps = [
        "//envlogger/proto:storage_cc_proto",
        "@com_github_google_glog//:glog",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_riegeli//riegeli/bytes:string_writer",
        "@com_google_riegeli//riegeli/endian:endian_reading",
        "@com_google_riegeli//riegeli/endian:endian_writing",
        "@gmp",
        "@xtensor",
    ],
)

# Unit test for the C++ codec library.
cc_test(
    name = "xtensor_codec_test",
    srcs = ["xtensor_codec_test.cc"],
    deps = [
        ":make_visitor",
        ":xtensor_codec",
        "//envlogger/platform:parse_text_proto",
        "//envlogger/platform:proto_testutil",
        "//envlogger/proto:storage_cc_proto",
        "@com_google_googletest//:gtest",
        "@com_google_googletest//:gtest_main",
        "@gmp",
    ],
)

# Header-only helper for building overloaded visitors for std::visit.
cc_library(
    name = "make_visitor",
    hdrs = ["make_visitor.h"],
)

# Unit test for the make_visitor helper.
cc_test(
    name = "make_visitor_test",
    srcs = ["make_visitor_test.cc"],
    deps = [
        ":make_visitor",
        "@com_google_googletest//:gtest",
        "@com_google_googletest//:gtest_main",
    ],
)
99 |
--------------------------------------------------------------------------------
/envlogger/converters/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/converters/make_visitor.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_
17 |
18 | // Template magic for more easily creating overloaded visitors for use with
19 | // std::visit.
20 | //
21 | // Example:
22 | //
23 | // const std::variant<Foo, Bar> my_variant = PopulateVariant();
24 | // const auto visitor = envlogger::MakeVisitor(
25 | // [](const Foo& foo) { DoFooStuff(foo); },
26 | // [](const Bar& bar) { DoBarStuff(bar); }
27 | // );
28 | // std::visit(visitor, my_variant);
29 |
namespace envlogger {

// Callable that inherits from every type in `Visitors` and exposes all of
// their operator() overloads. This relies on C++17 pack expansion in
// using-declarations, so overload resolution picks the matching lambda.
template <typename... Visitors>
struct Visitor : Visitors... {
  // Copy-constructs each base from the corresponding callable.
  explicit Visitor(const Visitors&... v) : Visitors(v)... {}
  using Visitors::operator()...;
};

// Builds a Visitor from the given callables, deducing their types so that
// callers never have to spell out lambda types.
template <typename... Visitors>
Visitor<Visitors...> MakeVisitor(Visitors... visitors) {
  return Visitor<Visitors...>(visitors...);
}

}  // namespace envlogger
46 |
47 | #endif // THIRD_PARTY_PY_ENVLOGGER_CONVERTERS_MAKE_VISITOR_H_
48 |
--------------------------------------------------------------------------------
/envlogger/converters/make_visitor_test.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/converters/make_visitor.h"
16 |
17 | #include
18 | #include
19 | #include
20 | #include
21 |
22 | #include "gmock/gmock.h"
23 | #include "gtest/gtest.h"
24 |
25 | namespace envlogger {
26 | namespace {
27 |
28 | using ::testing::Eq;
29 | using ::testing::IsFalse;
30 | using ::testing::IsTrue;
31 |
32 | TEST(MakeVisitorTest, SanityCheck) {
33 | bool executed_char_visitor = false;
34 | bool executed_double_visitor = false;
35 |
36 | const auto visitor = MakeVisitor(
37 | [&](char) {
38 | executed_char_visitor = true;
39 | executed_double_visitor = false;
40 | },
41 | [&](double) {
42 | executed_char_visitor = false;
43 | executed_double_visitor = true;
44 | },
45 | [&](int) { FAIL() << "This shouldn't be called."; });
46 |
47 | std::visit(visitor, std::variant('x'));
48 | EXPECT_THAT(executed_char_visitor, IsTrue());
49 | EXPECT_THAT(executed_double_visitor, IsFalse());
50 |
51 | std::visit(visitor, std::variant(1.5));
52 | EXPECT_THAT(executed_char_visitor, IsFalse());
53 | EXPECT_THAT(executed_double_visitor, IsTrue());
54 | }
55 |
// Generic callable that accepts a value of any type and returns the size in
// bytes of that type. Used to exercise MakeVisitor with a templated
// operator() alongside non-template lambdas.
struct SizeOfVisitor {
  template <typename T>
  std::size_t operator()(const T&) const {
    return sizeof(T);
  }
};
62 |
// Fixed-size payload type; its size is distinct from the other alternatives
// used in the tests, which makes a sizeof-based result easy to tell apart.
struct Padded {
  char pad[64];
};
// Qualified as std::size_t to match the spelling used elsewhere in this file.
constexpr std::size_t kPaddedSize = sizeof(Padded);
67 |
68 | TEST(MakeVisitorTest, TemplatedVisitor) {
69 | using VariantT = std::variant;
70 |
71 | const auto visitor =
72 | MakeVisitor([](std::nullptr_t) -> std::size_t { return 0; }, //
73 | SizeOfVisitor{}, //
74 | [](char) -> std::size_t { return -1; });
75 | EXPECT_THAT(std::visit(visitor, VariantT('x')), Eq(static_cast(-1)));
76 | EXPECT_THAT(std::visit(visitor, VariantT(1.5)), Eq(sizeof(double)));
77 | EXPECT_THAT(std::visit(visitor, VariantT(nullptr)), Eq(0));
78 | EXPECT_THAT(std::visit(visitor, VariantT(Padded{})), Eq(kPaddedSize));
79 | }
80 |
81 | } // namespace
82 | } // namespace envlogger
83 |
--------------------------------------------------------------------------------
/envlogger/converters/spec_codec.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Encoder/decoder for dm_env.specs.Array (and subclasses).
17 | """
18 |
19 | from typing import Any, Optional, Union
20 |
21 | import dm_env
22 | from dm_env import specs
23 | import numpy as np
24 |
25 |
# Names of the environment spec accessors handled by this module. They are
# also the keys of the dictionaries produced by `encode_environment_specs()`
# and `decode_environment_specs()`.
_ENVIRONMENT_SPEC_NAMES = [
    'observation_spec', 'action_spec', 'reward_spec', 'discount_spec',
]
32 |
33 |
def encode_environment_specs(
    env: Optional[dm_env.Environment],
) -> dict[str, Any]:
  """Encodes all the specs from a given environment.

  Args:
    env: The environment whose specs should be encoded, or None.

  Returns:
    A dict mapping every name in `_ENVIRONMENT_SPEC_NAMES` to the `encode()`d
    result of calling the same-named method on `env`. An empty dict if `env`
    is None.
  """
  # Compare against None explicitly instead of relying on truthiness: an
  # environment class that defines `__len__` or `__bool__` could evaluate as
  # false and would otherwise silently produce no specs.
  if env is None:
    return {}
  return {
      spec_name: encode(getattr(env, spec_name)())
      for spec_name in _ENVIRONMENT_SPEC_NAMES
  }
46 |
47 |
def decode_environment_specs(
    encoded_specs: dict[str, Any],
) -> dict[str, Optional[specs.Array]]:
  """Decodes all the specs of an environment.

  Args:
    encoded_specs: Dict holding one encoded spec per name in
      `_ENVIRONMENT_SPEC_NAMES`. May be empty.

  Returns:
    A dict keyed by the names in `_ENVIRONMENT_SPEC_NAMES`. The values are the
    decoded specs, or None for every key when `encoded_specs` is empty.
  """
  if not encoded_specs:
    # Nothing was encoded: report every spec as absent.
    return dict.fromkeys(_ENVIRONMENT_SPEC_NAMES)
  decoded_specs = {}
  for spec_name in _ENVIRONMENT_SPEC_NAMES:
    decoded_specs[spec_name] = decode(encoded_specs[spec_name])
  return decoded_specs  # pytype: disable=bad-return-type  # always-use-return-annotations
56 |
57 |
def _array_spec_to_dict(array_spec: specs.Array) -> dict[str, Any]:
  """Encodes an Array spec as a dictionary.

  Args:
    array_spec: The spec to encode.

  Returns:
    A dict with 'shape' (an int64 numpy array), 'dtype' (a string) and 'name'.
    `specs.BoundedArray` instances additionally contribute 'minimum' and
    'maximum'; `specs.DiscreteArray` instances additionally contribute
    'num_values'.
  """
  encoded = dict(
      shape=np.array(array_spec.shape, dtype=np.int64),
      dtype=str(array_spec.dtype),
      name=array_spec.name,
  )
  # The two checks are deliberately not exclusive: a spec that is an instance
  # of both classes gets all three extra keys.
  if isinstance(array_spec, specs.BoundedArray):
    encoded['minimum'] = array_spec.minimum
    encoded['maximum'] = array_spec.maximum
  if isinstance(array_spec, specs.DiscreteArray):
    encoded['num_values'] = array_spec.num_values
  return encoded
73 |
74 |
def encode(
    spec: Union[specs.Array, list[Any], tuple[Any, ...], dict[str, Any]],
) -> Union[list[Any], tuple[Any, ...], dict[str, Any]]:
  """Encodes `spec` using plain Python objects.

  Bare Array specs, lists, tuples and dicts of Array specs are supported, as
  is any nesting of those containers, e.g. Dict[str, Tuple[List[Array, Array]]].

  Args:
    spec: The actual spec to encode.
  Returns:
    The same spec encoded in a way that can be serialized to disk. Containers
    keep their type (list, tuple or dict); Array specs become dicts.
  Raises:
    TypeError: When the argument is not among the supported types.
  """
  # Leaf case: a single Array spec (or subclass).
  if isinstance(spec, specs.Array):
    return _array_spec_to_dict(spec)
  # Container cases: encode every element and rebuild the same container kind.
  if isinstance(spec, list):
    return list(map(encode, spec))
  if isinstance(spec, tuple):
    return tuple(map(encode, spec))
  if isinstance(spec, dict):
    return {key: encode(value) for key, value in spec.items()}
  raise TypeError(
      'encode() should be called with an argument of type specs.Array (and '
      f'subclasses), list, tuple or dict. Found {type(spec)}: {spec}.')
102 |
103 |
def decode(
    spec: Union[list[Any], tuple[Any, ...], dict[str, Any]],
) -> Union[specs.Array, list[Any], tuple[Any, ...], dict[str, Any]]:
  """Parses `spec` into the supported dm_env spec formats.

  Args:
    spec: A list, tuple or dict. A dict that contains both 'shape' and 'dtype'
      is treated as an encoded Array spec; any other dict, list or tuple is
      decoded element by element.

  Returns:
    A `specs.DiscreteArray` when the encoded spec has 'num_values', a
    `specs.BoundedArray` when it has both 'minimum' and 'maximum', a plain
    `specs.Array` otherwise; or a container of the same kind as `spec` holding
    the decoded elements.

  Raises:
    TypeError: When `spec` is not a list, tuple or dict.
  """
  if isinstance(spec, list):
    return [decode(element) for element in spec]
  if isinstance(spec, tuple):
    return tuple(map(decode, spec))
  if not isinstance(spec, dict):
    raise TypeError(
        'decode() should be called with an argument of type list, tuple or dict.'
        f' Found: {type(spec)}: {spec}.')

  is_encoded_array = 'shape' in spec and 'dtype' in spec
  if not is_encoded_array:
    # A plain mapping of names to nested specs: decode each value.
    return {key: decode(value) for key, value in spec.items()}

  # A stored shape of None stands for the empty shape.
  shape = spec['shape'] if spec['shape'] is not None else ()
  if 'num_values' in spec:
    # DiscreteArray case; the shape is not passed to the constructor.
    return specs.DiscreteArray(
        num_values=spec['num_values'],
        dtype=spec['dtype'],
        name=spec['name'])
  if 'minimum' in spec and 'maximum' in spec:
    # BoundedArray case.
    return specs.BoundedArray(
        shape=shape,
        dtype=spec['dtype'],
        minimum=spec['minimum'],
        maximum=spec['maximum'],
        name=spec['name'])
  # Base Array spec case.
  return specs.Array(shape=shape, dtype=spec['dtype'], name=spec['name'])
137 |
--------------------------------------------------------------------------------
/envlogger/environment_wrapper.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base class for implementing environment wrappers."""
17 |
18 | import pickle
19 |
20 | import dm_env
21 |
22 |
class EnvironmentWrapper(dm_env.Environment):
  """An Environment which delegates calls to another environment.

  Subclasses should override one or more methods to modify the behavior of the
  backing environment as desired per the Decorator Pattern.

  This exposes the wrapped environment to subclasses with the `._environment`
  attribute and also defines `__getattr__` so that attributes are invisibly
  forwarded to the wrapped environment (and hence enabling duck-typing).
  """

  def __init__(self, environment: dm_env.Environment):
    """Stores `environment` as the delegate for every forwarded call."""
    self._environment = environment

  def __getattr__(self, name):
    """Forwards lookups of attributes not found on the wrapper itself."""
    # `__getattr__` only runs after normal lookup has failed. If the missing
    # attribute is `_environment` itself (e.g. on an instance created with
    # `__new__` before `__init__`/`__setstate__` ran), forwarding would look up
    # `self._environment` again and recurse until RecursionError. Fail with a
    # regular AttributeError instead so `hasattr`/`getattr(..., default)` work.
    if name == '_environment':
      raise AttributeError(
          f'{type(self).__name__!r} object has no attribute {name!r}')
    return getattr(self._environment, name)

  def __getstate__(self):
    """Returns the wrapped environment serialized with pickle."""
    return pickle.dumps(self._environment)

  def __setstate__(self, state):
    """Restores the wrapped environment from `__getstate__`'s output."""
    # NOTE: unpickling can execute arbitrary code; only restore wrappers from
    # trusted sources.
    self._environment = pickle.loads(state)

  def step(self, action) -> dm_env.TimeStep:
    return self._environment.step(action)

  def reset(self) -> dm_env.TimeStep:
    return self._environment.reset()

  def action_spec(self):
    return self._environment.action_spec()

  def discount_spec(self):
    return self._environment.discount_spec()

  def observation_spec(self):
    return self._environment.observation_spec()

  def reward_spec(self):
    return self._environment.reward_spec()

  def close(self):
    return self._environment.close()
66 |
--------------------------------------------------------------------------------
/envlogger/examples/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Examples.
16 | load("@rules_python//python:defs.bzl", "py_binary")
17 |
18 | package(default_visibility = ["//visibility:public"])
19 |
20 | py_binary(
21 | name = "random_agent_catch",
22 | srcs = ["random_agent_catch.py"],
23 | deps = [
24 | "//envlogger",
25 | "//envlogger/testing:catch_env",
26 | ],
27 | )
28 |
29 | py_binary(
30 | name = "tfds_random_agent_catch",
31 | srcs = ["tfds_random_agent_catch.py"],
32 | deps = [
33 | "//envlogger",
34 | "//envlogger/backends:tfds_backend_writer",
35 | "//envlogger/testing:catch_env",
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/envlogger/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/examples/random_agent_catch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple binary to run catch for a while and record its trajectories.
17 | """
18 |
19 | import time
20 |
21 | from absl import app
22 | from absl import flags
23 | from absl import logging
24 | import envlogger
25 | from envlogger.testing import catch_env
26 | import numpy as np
27 |
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to log.')
32 | flags.DEFINE_string('trajectories_dir', '/tmp/catch_data/',
33 | 'Path in a filesystem to record trajectories.')
34 |
35 |
def main(unused_argv):
  """Runs a random agent on Catch and logs the episodes with EnvLogger.

  The number of episodes and the output directory come from the
  `--num_episodes` and `--trajectories_dir` flags.
  """
  logging.info('Creating Catch environment...')
  catch = catch_env.Catch()
  logging.info('Done creating Catch environment.')

  def step_fn(unused_timestep, unused_action, unused_env):
    # Handed to EnvLogger as `step_fn`; returns the current wall-clock time.
    return {'timestamp': time.time()}

  logging.info('Wrapping environment with EnvironmentLogger...')
  num_episodes = FLAGS.num_episodes
  logger_metadata = {
      'agent_type': 'random',
      # Taken from the unwrapped environment, before EnvLogger wraps it.
      'env_type': type(catch).__name__,
      'num_episodes': num_episodes,
  }
  with envlogger.EnvLogger(
      catch,
      data_directory=FLAGS.trajectories_dir,
      max_episodes_per_file=1000,
      metadata=logger_metadata,
      step_fn=step_fn) as env:
    logging.info('Done wrapping environment with EnvironmentLogger.')

    logging.info('Training a random agent for %r episodes...', num_episodes)
    for episode in range(num_episodes):
      logging.info('episode %r', episode)
      timestep = env.reset()
      while not timestep.last():
        # Uniformly random action drawn from {0, 1, 2}.
        timestep = env.step(np.random.randint(low=0, high=3))
    logging.info('Done training a random agent for %r episodes.',
                 num_episodes)
67 |
68 |
69 | if __name__ == '__main__':
70 | app.run(main)
71 |
--------------------------------------------------------------------------------
/envlogger/examples/tfds_random_agent_catch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple binary to run catch for a while and record its trajectories.
17 | """
18 |
19 | import time
20 |
21 | from absl import app
22 | from absl import flags
23 | from absl import logging
24 | import envlogger
25 | from envlogger.backends import tfds_backend_writer
26 |
27 | from envlogger.testing import catch_env
28 | import numpy as np
29 | import tensorflow as tf
30 | import tensorflow_datasets as tfds
31 |
32 |
33 |
34 | FLAGS = flags.FLAGS
35 |
36 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to log.')
37 | flags.DEFINE_string('trajectories_dir', '/tmp/catch_data/',
38 | 'Path in a filesystem to record trajectories.')
39 |
40 |
def main(unused_argv):
  """Runs a random agent on Catch and logs the episodes through TFDS.

  The number of episodes and the output directory come from the
  `--num_episodes` and `--trajectories_dir` flags.
  """
  logging.info('Creating Catch environment...')
  env = catch_env.Catch()
  logging.info('Done creating Catch environment.')

  def step_fn(unused_timestep, unused_action, unused_env):
    # Handed to EnvLogger as `step_fn`; returns the current wall-clock time.
    return {'timestamp': time.time()}

  dataset_config = tfds.rlds.rlds_base.DatasetConfig(
      name='catch_example',
      observation_info=tfds.features.Tensor(
          shape=(10, 5), dtype=tf.float32,
          encoding=tfds.features.Encoding.ZLIB),
      action_info=tf.int64,
      reward_info=tf.float64,
      discount_info=tf.float64,
      # `time.time()` returns seconds since the epoch (around 1.7e9 today).
      # float32 has a 24-bit significand, so values of that magnitude would be
      # rounded to multiples of 128 seconds; float64 keeps sub-second precision.
      step_metadata_info={'timestamp': tf.float64})

  logging.info('Wrapping environment with EnvironmentLogger...')
  with envlogger.EnvLogger(
      env,
      step_fn=step_fn,
      backend=tfds_backend_writer.TFDSBackendWriter(
          data_directory=FLAGS.trajectories_dir,
          split_name='train',
          max_episodes_per_file=500,
          ds_config=dataset_config),
  ) as env:
    logging.info('Done wrapping environment with EnvironmentLogger.')

    logging.info('Training a random agent for %r episodes...',
                 FLAGS.num_episodes)
    for i in range(FLAGS.num_episodes):
      logging.info('episode %r', i)
      timestep = env.reset()
      while not timestep.last():
        # Uniformly random action drawn from {0, 1, 2}.
        action = np.random.randint(low=0, high=3)
        timestep = env.step(action)
    logging.info('Done training a random agent for %r episodes.',
                 FLAGS.num_episodes)
81 |
82 |
83 | if __name__ == '__main__':
84 | app.run(main)
85 |
--------------------------------------------------------------------------------
/envlogger/platform/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Utilities
16 |
17 | package(default_visibility = ["//visibility:public"])
18 |
19 | cc_library(
20 | name = "bundle",
21 | hdrs = ["bundle.h"],
22 | deps = ["//envlogger/platform/default:bundle"],
23 | )
24 |
25 | cc_library(
26 | name = "filesystem",
27 | hdrs = ["filesystem.h"],
28 | deps = [
29 | ":status_macros",
30 | "//envlogger/platform/default:filesystem",
31 | "@com_google_absl//absl/status",
32 | "@com_google_absl//absl/status:statusor",
33 | "@com_google_absl//absl/strings",
34 | ],
35 | )
36 |
37 | cc_library(
38 | name = "parse_text_proto",
39 | hdrs = ["parse_text_proto.h"],
40 | deps = ["//envlogger/platform/default:parse_text_proto"],
41 | )
42 |
43 | cc_library(
44 | name = "proto_testutil",
45 | testonly = 1,
46 | hdrs = ["proto_testutil.h"],
47 | deps = ["//envlogger/platform/default:proto_testutil"],
48 | )
49 |
50 | cc_library(
51 | name = "riegeli_file_reader",
52 | hdrs = ["riegeli_file_reader.h"],
53 | deps = ["//envlogger/platform/default:riegeli_file_reader"],
54 | )
55 |
56 | cc_library(
57 | name = "riegeli_file_writer",
58 | hdrs = ["riegeli_file_writer.h"],
59 | deps = ["//envlogger/platform/default:riegeli_file_writer"],
60 | )
61 |
62 | cc_library(
63 | name = "status_macros",
64 | hdrs = ["status_macros.h"],
65 | deps = ["//envlogger/platform/default:status_macros"],
66 | )
67 |
68 | cc_library(
69 | name = "test_macros",
70 | testonly = 1,
71 | hdrs = ["test_macros.h"],
72 | deps = [
73 | ":status_macros",
74 | "@com_google_googletest//:gtest",
75 | ],
76 | )
77 |
--------------------------------------------------------------------------------
/envlogger/platform/bundle.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_
17 |
18 | #include "envlogger/platform/default/bundle.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_BUNDLE_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/default/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Default platform specific targets.
16 |
17 | package(default_visibility = ["//visibility:public"])
18 |
19 | cc_library(
20 | name = "bundle",
21 | srcs = ["bundle.cc"],
22 | hdrs = ["bundle.h"],
23 | deps = ["@com_github_google_glog//:glog"],
24 | )
25 |
# Default (std::filesystem based) implementation of the filesystem helpers.
cc_library(
    name = "filesystem",
    srcs = ["filesystem.cc"],
    hdrs = ["filesystem.h"],
    linkopts = ["-lstdc++fs"],
    deps = [
        "//envlogger/platform:status_macros",
        # filesystem.cc includes absl/status/status.h directly, so depend on
        # the status target explicitly instead of relying on statusor to
        # provide it transitively.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)
37 |
38 | cc_library(
39 | name = "parse_text_proto",
40 | hdrs = ["parse_text_proto.h"],
41 | deps = ["@com_github_google_glog//:glog"],
42 | )
43 |
44 | cc_library(
45 | name = "proto_testutil",
46 | testonly = 1,
47 | hdrs = ["proto_testutil.h"],
48 | deps = [
49 | "@com_github_google_glog//:glog",
50 | "@com_google_googletest//:gtest",
51 | ],
52 | )
53 |
54 | cc_library(
55 | name = "riegeli_file_reader",
56 | hdrs = ["riegeli_file_reader.h"],
57 | deps = [
58 | ":filesystem",
59 | "@com_google_absl//absl/strings",
60 | "@com_google_riegeli//riegeli/bytes:fd_reader",
61 | ],
62 | )
63 |
64 | cc_library(
65 | name = "riegeli_file_writer",
66 | hdrs = ["riegeli_file_writer.h"],
67 | deps = [
68 | ":filesystem",
69 | "@com_google_absl//absl/strings",
70 | "@com_google_riegeli//riegeli/bytes:fd_writer",
71 | ],
72 | )
73 |
74 | cc_library(
75 | name = "source_location",
76 | hdrs = ["source_location.h"],
77 | )
78 |
79 | cc_library(
80 | name = "status_builder",
81 | srcs = ["status_builder.cc"],
82 | hdrs = ["status_builder.h"],
83 | deps = [
84 | ":source_location",
85 | "@com_google_absl//absl/base:core_headers",
86 | "@com_google_absl//absl/status",
87 | "@com_google_absl//absl/strings",
88 | ],
89 | )
90 |
91 | cc_library(
92 | name = "status_macros",
93 | hdrs = ["status_macros.h"],
94 | deps = [
95 | ":status_builder",
96 | "@com_google_absl//absl/status",
97 | ],
98 | )
99 |
--------------------------------------------------------------------------------
/envlogger/platform/default/bundle.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/platform/default/bundle.h"
16 |
17 | #include
18 | #include // NOLINT(build/c++11)
19 | #include
20 |
21 | #include "glog/logging.h"
22 |
23 | namespace envlogger {
24 | namespace thread {
25 |
26 | Bundle::Bundle() : finished_(false) {}
27 | Bundle::~Bundle() {
28 | CHECK(finished_) << "JoinAll() should be called before releasing the bundle.";
29 | }
30 |
31 | void Bundle::Add(std::function&& function) {
32 | CHECK(!finished_) << "Add cannot be called after JoinAll is invoked.";
33 | futures_.push_back(std::async(std::launch::async, std::move(function)));
34 | }
35 |
36 | void Bundle::JoinAll() {
37 | CHECK(!finished_) << "JoinAll should be called only once.";
38 | finished_ = true;
39 | for (const auto& future : futures_) {
40 | future.wait();
41 | }
42 | futures_.clear();
43 | }
44 |
45 | } // namespace thread
46 | } // namespace envlogger
47 |
--------------------------------------------------------------------------------
/envlogger/platform/default/bundle.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_
17 |
18 | #include
19 | #include // NOLINT(build/c++11)
20 | #include
21 |
namespace envlogger {
namespace thread {

// Bundles a set of parallel function calls.
//
// Typical use: Add() every function to run, then call JoinAll() exactly once
// before the Bundle goes out of scope.
class Bundle {
 public:
  Bundle();
  ~Bundle();

  // Adds the function for asynchronous execution. It cannot be called if
  // JoinAll() has been invoked.
  void Add(std::function<void()>&& function);

  // Waits for the execution of the added asynchronous functions to terminate.
  // It must be called before releasing the Bundle object.
  void JoinAll();

 private:
  // Set once JoinAll() has been invoked.
  bool finished_ = false;
  // One future per function handed to Add().
  std::vector<std::future<void>> futures_;
};

}  // namespace thread
}  // namespace envlogger
46 |
47 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_BUNDLE_H_
48 |
--------------------------------------------------------------------------------
/envlogger/platform/default/filesystem.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include
16 | #include
17 | #include
18 | #include // NOLINT(build/c++11)
19 | #include
20 |
21 | #include "absl/status/status.h"
22 | #include "absl/status/statusor.h"
23 | #include "absl/strings/str_cat.h"
24 | #include "absl/strings/string_view.h"
25 |
26 | namespace envlogger {
27 | namespace file {
28 |
29 | std::string JoinPath(absl::string_view dirname, absl::string_view basename) {
30 | return std::filesystem::path(dirname) / basename;
31 | }
32 |
33 | absl::Status CreateDir(absl::string_view path) {
34 | if (!std::filesystem::create_directory(path)) {
35 | return absl::InternalError(
36 | absl::StrCat("Unable to create directory ", path));
37 | }
38 | return absl::OkStatus();
39 | }
40 |
41 | absl::StatusOr> GetSubdirectories(
42 | absl::string_view path, absl::string_view sentinel) {
43 | std::vector subdirs;
44 | for (const auto& entry : std::filesystem::directory_iterator(path)) {
45 | if (entry.is_directory()) {
46 | if (!sentinel.empty() &&
47 | !std::filesystem::exists(entry.path() / sentinel)) {
48 | continue;
49 | }
50 | subdirs.push_back(entry.path());
51 | }
52 | }
53 | return subdirs;
54 | }
55 |
56 | absl::Status RecursivelyDelete(absl::string_view path) {
57 | if (!std::filesystem::remove_all(path)) {
58 | return absl::InternalError(
59 | absl::StrCat("Unable to recursively delete directory ", path));
60 | }
61 | return absl::OkStatus();
62 | }
63 |
64 | absl::StatusOr GetSize(absl::string_view path) {
65 | std::error_code ec;
66 | const std::uintmax_t size = std::filesystem::file_size(path, ec);
67 | if (ec) {
68 | return absl::NotFoundError(absl::StrCat("Could not find file ", path));
69 | }
70 | return static_cast(size);
71 | }
72 |
73 | } // namespace file
74 | } // namespace envlogger
75 |
--------------------------------------------------------------------------------
/envlogger/platform/default/filesystem.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_
17 |
18 | // Intentionally empty.
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_FILESYSTEM_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/default/parse_text_proto.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_
17 |
18 | #include "glog/logging.h"
19 | #include "google/protobuf/text_format.h"
20 |
21 | namespace envlogger {
22 | // Forward declarations for the friend statement.
23 | namespace internal {
24 | class ParseProtoHelper;
25 | } // namespace internal
26 | internal::ParseProtoHelper ParseTextProtoOrDie(const std::string& input);
27 |
28 | namespace internal {
29 | // Helper that infers the target proto type from the conversion context.
30 | class ParseProtoHelper {
31 |  public:
32 |   template <typename T>
33 |   operator T() {  // Parses `input_` as text-format T; CHECK-fails on error.
34 |     T result;
35 |     CHECK(google::protobuf::TextFormat::ParseFromString(input_, &result));
36 |     return result;
37 |   }
38 |
39 |  private:
40 |   friend ParseProtoHelper envlogger::ParseTextProtoOrDie(
41 |       const std::string& input);
42 |   ParseProtoHelper(const std::string& input) : input_(input) {}
43 |   // Held by reference, not copied: the caller's string must outlive this.
44 |   const std::string& input_;
45 | };
46 | } // namespace internal
47 |
48 | // Parses the specified ASCII protocol buffer message or dies.
49 | inline internal::ParseProtoHelper ParseTextProtoOrDie(const std::string& input) {
50 |   return internal::ParseProtoHelper(input);  // Parsing happens on conversion.
51 | }
52 |
53 | // Parses the specified ASCII protocol buffer message and returns it or dies.
54 | template <typename T>
55 | T ParseTextOrDie(const std::string& input) {
56 |   return ParseTextProtoOrDie(input);  // Converts the helper to T, parsing it.
57 | }
58 |
59 | } // namespace envlogger
60 |
61 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PARSE_TEXT_PROTO_H_
62 |
--------------------------------------------------------------------------------
/envlogger/platform/default/proto_testutil.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_
17 |
18 | #include <string>
19 |
20 | #include "glog/logging.h"
21 | #include "google/protobuf/text_format.h"
22 | #include "google/protobuf/util/message_differencer.h"
23 | #include "gmock/gmock.h"
24 |
25 | namespace envlogger {
26 | namespace internal {
27 |
28 | // Simple proto matcher. The expected proto is stored in text form, re-parsed
29 | // into the actual message's type at match time and compared with
30 | // MessageDifferencer. IMPORTANT: Only use this for protos whose textual
31 | // representation is deterministic (may not hold for the map collection type).
32 | class ProtoStringMatcher {
33 |  public:
34 |   explicit ProtoStringMatcher(const std::string& expected)
35 |       : expected_proto_str_(expected) {}
36 |   explicit ProtoStringMatcher(const google::protobuf::Message& expected)
37 |       : expected_proto_str_(expected.DebugString()) {}
38 |   // Re-parses the expected text as `Message` and compares it to `actual_proto`.
39 |   template <typename Message>
40 |   bool MatchAndExplain(const Message& actual_proto,
41 |                        ::testing::MatchResultListener* listener) const;
42 |
43 |   void DescribeTo(::std::ostream* os) const { *os << expected_proto_str_; }
44 |   void DescribeNegationTo(::std::ostream* os) const {
45 |     *os << "not equal to expected message: " << expected_proto_str_;
46 |   }
47 |   // Switches the comparison scope from the default FULL to PARTIAL.
48 |   void SetComparePartially() {
49 |     scope_ = ::google::protobuf::util::MessageDifferencer::PARTIAL;
50 |   }
51 |
52 |  private:
53 |   const std::string expected_proto_str_;
54 |   google::protobuf::util::MessageDifferencer::Scope scope_ =
55 |       google::protobuf::util::MessageDifferencer::FULL;
56 | };
57 |
58 | template <typename T>
59 | T CreateProto(const std::string& textual_proto) {
60 |   T proto;  // CHECK-fails below if `textual_proto` is not valid text format.
61 |   CHECK(google::protobuf::TextFormat::ParseFromString(textual_proto, &proto));
62 |   return proto;
63 | }
64 |
65 | template <typename Message>
66 | bool ProtoStringMatcher::MatchAndExplain(
67 |     const Message& actual_proto,
68 |     ::testing::MatchResultListener* listener) const {
69 |   Message expected_proto = CreateProto<Message>(expected_proto_str_);
70 |   // Collect a textual diff and honor the configured FULL/PARTIAL scope.
71 |   google::protobuf::util::MessageDifferencer differencer;
72 |   std::string differences;
73 |   differencer.ReportDifferencesToString(&differences);
74 |   differencer.set_scope(scope_);
75 |   // On mismatch, surface the collected diff through the match listener.
76 |   if (!differencer.Compare(expected_proto, actual_proto)) {
77 |     *listener << "the protos are different:\n" << differences;
78 |     return false;
79 |   }
80 |
81 |   return true;
82 | }
83 | } // namespace internal
84 |
85 | // Polymorphic matcher for any proto equal to the text-format proto in `x`.
86 | inline ::testing::PolymorphicMatcher<internal::ProtoStringMatcher> EqualsProto(
87 |     const std::string& x) {
88 |   return ::testing::MakePolymorphicMatcher(internal::ProtoStringMatcher(x));
89 | }
90 |
91 | // Polymorphic matcher for any proto equal to `x` (kept as its DebugString()).
92 | inline ::testing::PolymorphicMatcher<internal::ProtoStringMatcher> EqualsProto(
93 |     const google::protobuf::Message& x) {
94 |   return ::testing::MakePolymorphicMatcher(internal::ProtoStringMatcher(x));
95 | }
96 |
97 | // Only compare the fields populated in the matcher proto.
98 | template <class InnerProtoMatcher>
99 | inline InnerProtoMatcher Partially(InnerProtoMatcher inner_proto_matcher) {
100 |   inner_proto_matcher.mutable_impl().SetComparePartially();
101 |   return inner_proto_matcher;  // Returned by value with PARTIAL scope set.
102 | }
103 |
104 | } // namespace envlogger
105 |
106 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_PROTO_TESTUTIL_H_
107 |
--------------------------------------------------------------------------------
/envlogger/platform/default/riegeli_file_reader.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_
17 |
18 | #include "riegeli/bytes/fd_reader.h"
19 |
20 | namespace envlogger {
21 | // Reader type of the default platform: riegeli's FdReader with its defaults.
22 | using RiegeliFileReader = ::riegeli::FdReader<>;
23 |
24 | } // namespace envlogger
25 |
26 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_READER_H_
27 |
--------------------------------------------------------------------------------
/envlogger/platform/default/riegeli_file_writer.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_
17 |
18 | #include "riegeli/bytes/fd_writer.h"
19 |
20 | namespace envlogger {
21 | // Writer type of the default platform: riegeli's FdWriter with its defaults.
22 | using RiegeliFileWriter = ::riegeli::FdWriter<>;
23 |
24 | } // namespace envlogger
25 |
26 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_RIEGELI_FILE_WRITER_H_
27 |
--------------------------------------------------------------------------------
/envlogger/platform/default/source_location.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_
17 |
18 | #include <cstdint>
19 |
20 | namespace envlogger {
21 | namespace internal {
22 |
23 | // A lightweight value describing a position (file + line) in the program's
24 | // source code. Instances can be copied freely.
25 | class source_location {
26 |  public:
27 |   // Produces a placeholder location (line 0, null file name); prefer LOC.
28 |   constexpr source_location() : source_location(0, nullptr) {}
29 |
30 |   // Factory forwarding to the private constructor. Only the LOC macro is
31 |   // meant to call it, hence the name.
32 |   static constexpr source_location DoNotInvokeDirectly(std::uint_least32_t line,
33 |                                                        const char* file_name) {
34 |     return {line, file_name};
35 |   }
36 |
37 |   // Line number of the captured location.
38 |   constexpr std::uint_least32_t line() const { return line_; }
39 |
40 |   // File name of the captured location.
41 |   constexpr const char* file_name() const { return file_name_; }
42 |
43 |   // There is no way to support column() and function_name() here, so they
44 |   // are intentionally not provided.
45 |
46 |  private:
47 |   // Reserved for DoNotInvokeDirectly() (i.e. the LOC macro); do not call it
48 |   // from anywhere else.
49 |   //
50 |   // `file_name` has to outlive every copy of the source_location, which in
51 |   // practice means it should be a string literal.
52 |   constexpr source_location(std::uint_least32_t line, const char* file_name)
53 |       : line_{line}, file_name_{file_name} {}
54 |
55 |   std::uint_least32_t line_;
56 |   const char* file_name_;
57 | };
58 |
59 | } // namespace internal
60 | } // namespace envlogger
61 |
62 | // Pass LOC to any function expecting a source_location for the call site.
63 | #define LOC                                                    \
64 |   ::envlogger::internal::source_location::DoNotInvokeDirectly( \
65 |       __LINE__, __FILE__)
66 |
67 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_SOURCE_LOCATION_H_
68 |
--------------------------------------------------------------------------------
/envlogger/platform/default/status_builder.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "envlogger/platform/default/status_builder.h"
16 |
17 | #include <memory>
18 | #include <sstream>
19 | #include <string>
20 |
21 | namespace envlogger {
22 | namespace internal {
23 |
24 | StatusBuilder::StatusBuilder(const StatusBuilder& sb)
25 |     : status_(sb.status_),
26 |       line_(sb.line_),
27 |       file_(sb.file_),
28 |       no_logging_(sb.no_logging_),
29 |       // `ate`: later `<<` must append to, not overwrite, the copied text.
30 |       stream_(new std::ostringstream(sb.stream_->str(), std::ios_base::ate)),
31 |       join_style_(sb.join_style_) {}
32 |
33 | StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
34 |   status_ = sb.status_;
35 |   file_ = sb.file_;
36 |   line_ = sb.line_;
37 |   no_logging_ = sb.no_logging_;
38 |   stream_.reset(new std::ostringstream(sb.stream_->str(), std::ios_base::ate));
39 |   join_style_ = sb.join_style_;
40 |   return *this;  // Note: `ate` above makes later `<<` append, not overwrite.
41 | }
42 |
43 | StatusBuilder& StatusBuilder::SetAppend() {
44 |   // Only takes effect on a non-OK status; an OK builder is left unchanged.
45 |   if (!status_.ok()) join_style_ = MessageJoinStyle::kAppend;
46 |   return *this;
47 | }
48 |
49 | StatusBuilder& StatusBuilder::SetPrepend() {
50 |   // Only takes effect on a non-OK status; an OK builder is left unchanged.
51 |   if (!status_.ok()) join_style_ = MessageJoinStyle::kPrepend;
52 |   return *this;
53 | }
54 |
55 | StatusBuilder& StatusBuilder::SetNoLogging() {
56 | no_logging_ = true;  // Makes the absl::Status conversions return status_ as is.
57 | return *this;
58 | }
59 |
60 | StatusBuilder::operator absl::Status() const& {
61 |   const bool has_extra_message = !no_logging_ && !stream_->str().empty();
62 |   // JoinMessageToStatus() is non-const, so the join runs on a temporary copy.
63 |   if (has_extra_message) return StatusBuilder(*this).JoinMessageToStatus();
64 |   return status_;
65 | }
66 |
67 | StatusBuilder::operator absl::Status() && {
68 |   const bool has_extra_message = !no_logging_ && !stream_->str().empty();
69 |   // Nothing streamed (or SetNoLogging() called): hand back the status as is.
70 |   if (!has_extra_message) return status_;
71 |   return JoinMessageToStatus();
72 | }
73 |
74 | absl::Status StatusBuilder::JoinMessageToStatus() {
75 |   const std::string extra = stream_->str();  // Text streamed via `<<`.
76 |   const absl::string_view base = status_.message();
77 |   std::string joined;  // Left empty when annotating an OK status.
78 |   if (join_style_ == MessageJoinStyle::kPrepend) {
79 |     joined = absl::StrCat(extra, base);
80 |   } else if (join_style_ == MessageJoinStyle::kAppend) {
81 |     joined = absl::StrCat(base, extra);
82 |   } else if (!status_.ok()) {  // MessageJoinStyle::kAnnotate
83 |     joined = absl::StrCat(base, "; ", extra);
84 |   }
85 |   return absl::Status(status_.code(), joined);
86 | }
87 |
88 | } // namespace internal
89 | } // namespace envlogger
90 |
--------------------------------------------------------------------------------
/envlogger/platform/default/status_builder.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_
17 |
18 | #include <sstream>
19 |
20 | #include "absl/base/attributes.h"
21 | #include "absl/status/status.h"
22 | #include "absl/strings/str_cat.h"
23 | #include "absl/strings/string_view.h"
24 | #include "envlogger/platform/default/source_location.h"
25 |
26 | namespace envlogger {
27 | namespace internal {
28 |
29 | class ABSL_MUST_USE_RESULT StatusBuilder {
30 |  public:
31 |   StatusBuilder(const StatusBuilder& sb);
32 |   StatusBuilder& operator=(const StatusBuilder& sb);
33 |   // Creates a `StatusBuilder` based on an original status. If logging is
34 |   // enabled, it will use `location` as the location from which the log message
35 |   // occurs.
36 |   StatusBuilder(const absl::Status& original_status, source_location location)
37 |       : status_(original_status),
38 |         line_(location.line()),
39 |         file_(location.file_name()),
40 |         stream_(new std::ostringstream) {}
41 |
42 |   StatusBuilder(absl::Status&& original_status, source_location location)
43 |       : status_(std::move(original_status)),
44 |         line_(location.line()),
45 |         file_(location.file_name()),
46 |         stream_(new std::ostringstream) {}
47 |
48 |   // Creates a `StatusBuilder` from a status code. If logging is enabled, it
49 |   // will use `location` as the location from which the log message occurs.
50 |   StatusBuilder(absl::StatusCode code, source_location location)
51 |       : status_(code, ""),
52 |         line_(location.line()),
53 |         file_(location.file_name()),
54 |         stream_(new std::ostringstream) {}
55 |
56 |   StatusBuilder(const absl::Status& original_status, const char* file, int line)
57 |       : status_(original_status),
58 |         line_(line),
59 |         file_(file),
60 |         stream_(new std::ostringstream) {}
61 |
62 |   bool ok() const { return status_.ok(); }
63 |   // Joins the streamed text after the original message (non-OK status only).
64 |   StatusBuilder& SetAppend();
65 |   // Joins the streamed text before the original message (non-OK status only).
66 |   StatusBuilder& SetPrepend();
67 |   // Makes the conversion to absl::Status drop any streamed text.
68 |   StatusBuilder& SetNoLogging();
69 |   // Streams `msg` into the extra message; a no-op when the status is OK.
70 |   template <typename T>
71 |   StatusBuilder& operator<<(const T& msg) {
72 |     if (status_.ok()) return *this;
73 |     *stream_ << msg;
74 |     return *this;
75 |   }
76 |   // Convert to absl::Status, merging any streamed text per the join style.
77 |   operator absl::Status() const&;
78 |   operator absl::Status() &&;
79 |   // Builds a status with the joined message and the code of `status_`.
80 |   absl::Status JoinMessageToStatus();
81 |
82 |  private:
83 |   // Specifies how to join the error message in the original status and any
84 |   // additional message that has been streamed into the builder.
85 |   enum class MessageJoinStyle {
86 |     kAnnotate,  // "<original>; <streamed>" (applied to non-OK statuses only).
87 |     kAppend,    // "<original><streamed>"
88 |     kPrepend,   // "<streamed><original>"
89 |   };
90 |
91 |   // The status that the result will be based on.
92 |   absl::Status status_;
93 |   // The line to record if this status is logged.
94 |   int line_;
95 |   // Not-owned: The file to record if this status is logged.
96 |   const char* file_;
97 |   bool no_logging_ = false;
98 |   // The additional messages added with `<<`.
99 |   std::unique_ptr<std::ostringstream> stream_;
100 |   // Specifies how to join the message in `status_` and `stream_`.
101 |   MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate;
102 | };
103 |
104 | inline StatusBuilder AlreadyExistsErrorBuilder(source_location location) {
105 | return StatusBuilder(absl::StatusCode::kAlreadyExists, location);
106 | }
107 |
108 | inline StatusBuilder FailedPreconditionErrorBuilder(source_location location) {
109 | return StatusBuilder(absl::StatusCode::kFailedPrecondition, location);
110 | }
111 |
112 | inline StatusBuilder InternalErrorBuilder(source_location location) {
113 | return StatusBuilder(absl::StatusCode::kInternal, location);
114 | }
115 |
116 | inline StatusBuilder InvalidArgumentErrorBuilder(source_location location) {
117 | return StatusBuilder(absl::StatusCode::kInvalidArgument, location);
118 | }
119 |
120 | inline StatusBuilder NotFoundErrorBuilder(source_location location) {
121 | return StatusBuilder(absl::StatusCode::kNotFound, location);
122 | }
123 |
124 | inline StatusBuilder UnavailableErrorBuilder(source_location location) {
125 | return StatusBuilder(absl::StatusCode::kUnavailable, location);
126 | }
127 |
128 | inline StatusBuilder UnimplementedErrorBuilder(source_location location) {
129 | return StatusBuilder(absl::StatusCode::kUnimplemented, location);
130 | }
131 |
132 | inline StatusBuilder UnknownErrorBuilder(source_location location) {
133 | return StatusBuilder(absl::StatusCode::kUnknown, location);
134 | }
135 |
136 | } // namespace internal
137 | } // namespace envlogger
138 |
139 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_DEFAULT_STATUS_BUILDER_H_
140 |
--------------------------------------------------------------------------------
/envlogger/platform/filesystem.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_
17 |
18 | #include <cstdint>
19 | #include <string>
20 | #include <vector>
21 |
22 | #include "absl/status/status.h"
23 | #include "absl/status/statusor.h"
24 | #include "absl/strings/string_view.h"
25 | #include "envlogger/platform/default/filesystem.h"
26 |
27 | namespace envlogger {
28 | namespace file {
29 |
30 | // Joins a directory name and a base name into a single path.
31 | std::string JoinPath(absl::string_view dirname, absl::string_view basename);
32 |
33 | // Creates a directory with the given path.
34 | // Fails if the directory cannot be created (e.g. it already exists).
35 | absl::Status CreateDir(absl::string_view path);
36 |
37 | // Returns the list of subdirectories under the specified path. If the sentinel
38 | // is not empty, then only subdirectories that contain a file with that name
39 | // will be present.
40 | absl::StatusOr<std::vector<std::string>> GetSubdirectories(
41 |     absl::string_view path, absl::string_view sentinel = "");
42 |
43 | // Recursively deletes the specified path.
44 | absl::Status RecursivelyDelete(absl::string_view path);
45 |
46 | // Returns the file size in bytes.
47 | absl::StatusOr GetSize(absl::string_view path);
48 |
49 | } // namespace file
50 | } // namespace envlogger
51 |
52 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_FILESYSTEM_H_
53 |
--------------------------------------------------------------------------------
/envlogger/platform/parse_text_proto.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_
17 |
18 | #include "envlogger/platform/default/parse_text_proto.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PARSE_TEXT_PROTO_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/proto_testutil.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_
17 |
18 | #include "envlogger/platform/default/proto_testutil.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_PROTO_TESTUTIL_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/riegeli_file_reader.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_
17 |
18 | #include "envlogger/platform/default/riegeli_file_reader.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_READER_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/riegeli_file_writer.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_
17 |
18 | #include "envlogger/platform/default/riegeli_file_writer.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_RIEGELI_FILE_WRITER_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/status_macros.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_
17 |
18 | #include "envlogger/platform/default/status_macros.h"
19 |
20 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_STATUS_MACROS_H_
21 |
--------------------------------------------------------------------------------
/envlogger/platform/test_macros.h:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_
16 | #define THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_
17 |
18 | #include "gtest/gtest.h"
19 | #include "envlogger/platform/status_macros.h"
20 |
21 | #define CONCAT_IMPL(x, y) x##y
22 | #define CONCAT_MACRO(x, y) CONCAT_IMPL(x, y)
23 | // Evaluates `rexpr` once into a __COUNTER__-named temporary, then assigns.
24 | #define ENVLOGGER_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
25 | ENVLOGGER_ASSERT_OK_AND_ASSIGN_IMPL(CONCAT_MACRO(_status_or, __COUNTER__), \
26 | lhs, rexpr)
27 | // Asserts that the status is OK, then moves the contained value into `lhs`.
28 | #define ENVLOGGER_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
29 | auto statusor = (rexpr); \
30 | ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
31 | lhs = std::move(statusor.value())
32 | // NOTE(review): forwards to a CHECK macro defined elsewhere - confirm intent.
33 | #define ENVLOGGER_EXPECT_OK(expr) ENVLOGGER_CHECK_OK(expr)
34 |
35 | #endif // THIRD_PARTY_PY_ENVLOGGER_PLATFORM_TEST_MACROS_H_
36 |
--------------------------------------------------------------------------------
/envlogger/proto/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Protocol buffers for environment logger.
16 | load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
17 | load("@rules_proto//proto:defs.bzl", "proto_library")
18 |
19 | package(default_visibility = ["//visibility:public"])
20 |
21 | proto_library(  # Language-neutral target wrapping storage.proto.
22 | name = "storage_proto",
23 | srcs = ["storage.proto"],
24 | )
25 | # Python bindings, built directly from storage.proto (rule from protobuf.bzl).
26 | py_proto_library(
27 | name = "storage_py_pb2",
28 | srcs = ["storage.proto"],
29 | )
30 | # C++ bindings generated from the :storage_proto target above.
31 | cc_proto_library(
32 | name = "storage_cc_proto",
33 | deps = [":storage_proto"],
34 | )
35 |
--------------------------------------------------------------------------------
/envlogger/proto/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/proto/storage.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2024 DeepMind Technologies Limited..
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto3";
16 |
17 | package envlogger;
18 |
19 | // Datum is basically a multidimensional Matrix (or Tensor if you like).
message Datum {
  // Shape of the tensor: an ordered list of dimensions.
  message Shape {
    // One dimension of the tensor.
    message Dim {
      // Size of the tensor in that dimension.
      //
      // -438 is a dimension size reserved for pure scalars.
      // Pure scalars are different from zero-dimensional matrices in many
      // programming languages, so this value is used to differentiate between
      // them.
      // It comes from the sum of negative ASCII codes:
      // S = -83
      // C = -67
      // A = -65
      // L = -76
      // A = -65
      // R = -82
      // ------
      // -438
      int64 size = 1;
    }

    // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
    // for a 30 x 40 2D tensor. The names are optional.
    // NOTE(review): `Dim` as declared above only has a `size` field, so the
    // names in this example cannot currently be stored -- confirm whether a
    // name field was removed or is still planned.
    //
    // The order of entries in "dim" matters: It indicates the layout of the
    // values in the tensor in-memory representation.
    //
    // The first entry in "dim" is the outermost dimension used to layout the
    // values, the last entry is the innermost dimension. This matches the
    // in-memory layout of RowMajor Eigen tensors and xt::xarray.
    repeated Dim dim = 2;
  }

  // Shape describing how the flat `values` below are laid out.
  Shape shape = 1;

  // Storage of basic value types.
  // Please refer to the API of your programming language to determine the exact
  // type that's obtained from decoding Datums.
  message Values {
    // Types supported natively by protobuf.
    // 32-bit float.
    repeated float float_values = 1;
    // float32 values encoded as big-endian bytes.
    bytes float_values_buffer = 15;
    // 64-bit float.
    repeated double double_values = 2;
    // 32-bit signed integer.
    repeated int32 int32_values = 3;
    // 64-bit signed integer.
    repeated int64 int64_values = 4;
    // 32-bit unsigned integer.
    repeated uint32 uint32_values = 5;
    // 64-bit unsigned integer.
    repeated uint64 uint64_values = 6;
    // boolean value.
    repeated bool bool_values = 7;
    // string value.
    repeated string string_values = 8;

    // Opaque bytes to store arbitrary user-defined values.
    repeated bytes bytes_values = 9;

    // Arbitrarily long signed ints stored in big-endian byte order.
    // These correspond to normal Python int()s.
    repeated bytes bigint_values = 10;

    // Small ints are not supported by protobuf so we encode them as bytes in
    // big-endian byte order with fixed-length (1 byte for int8 and 2 bytes for
    // int16).
    // https://developers.google.com/protocol-buffers/docs/proto3#scalar

    // int8 corresponds to an np.int8 in Python and int8 in C++.
    bytes int8_values = 11;
    // int16 corresponds to an np.int16 in Python and int16 in C++.
    bytes int16_values = 12;
    // uint8 corresponds to an np.uint8 in Python and uint8 in C++.
    bytes uint8_values = 13;
    // uint16 corresponds to an np.uint16 in Python and uint16 in C++.
    bytes uint16_values = 14;
  }
  // The tensor's elements, stored in one of the typed fields of `Values`.
  Values values = 2;
}
103 |
// Data is a recursive container: a single Datum, or an Array, Tuple or Dict
// whose elements are themselves Data.
message Data {
  // Array represents a sequence of homogeneous elements.
  //
  // Our APIs expect and check that all elements in Array are of the same type.
  // An Array corresponds to a Python list() and to a C++ std::vector.
  message Array {
    repeated Data values = 1;
  }

  // Tuple represents a sequence of heterogeneous elements.
  //
  // A Tuple corresponds to a Python tuple() and to a C++ std::tuple<>.
  message Tuple {
    repeated Data values = 1;
  }

  // Dict represents a mapping from elements to elements.
  //
  // A Dict corresponds to a Python dict() and to a C++
  // std::{unordered_}map.
  message Dict {
    // Specialization for string keys.
    // A proto3 map field must spell out its key and value types; the keys are
    // strings and the values are nested Data elements.
    map<string, Data> values = 1;

    // A list of key-value pairs.
    // This field is used for arbitrary key types.
    // NOTE: This is currently only used by the Python codec. C++ support will
    // be added in a future change and this note will be removed.
    Array kvs = 2;
  }

  // Notice that the format does NOT differentiate between empty values, that
  // is, an empty Array is represented equally as an empty Dict or empty Tuple.
  // Empty `Data` objects are completely empty and do not occupy any space on
  // disk and are not sent over the wire.
  oneof value {
    Datum datum = 1;
    Array array = 2;
    Tuple tuple = 3;
    Dict dict = 4;
  }
}
146 |
--------------------------------------------------------------------------------
/envlogger/reader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Reader of EnvironmentLogger data."""
17 |
18 | from collections.abc import Sequence
19 | import copy
20 | from typing import Any, Optional, Union
21 |
22 | from absl import logging
23 | from envlogger import step_data
24 | from envlogger.backends import backend_reader
25 | from envlogger.backends import backend_type
26 | from envlogger.backends import in_memory_backend
27 | from envlogger.backends import riegeli_backend_reader
28 | from envlogger.converters import spec_codec
29 |
30 |
class Reader:
  """Reader of trajectories generated by EnvLogger.

  A `Reader` wraps a backend reader and exposes its steps, episodes and
  metadata, plus the environment specs decoded from the logged metadata.
  It can be used as a context manager; the backend is closed on exit.
  """

  def __init__(self,
               *backend_args,
               backend: Union[
                   backend_reader.BackendReader,
                   backend_type.BackendType] = backend_type.BackendType.RIEGELI,
               **backend_kwargs):
    """Creates a reader on top of the given backend.

    Args:
      *backend_args: Positional arguments forwarded to the backend reader's
        constructor when `backend` is a `BackendType`.
      backend: Either an already constructed `BackendReader`, which is used
        as is, or a `BackendType` selecting which backend reader to build.
      **backend_kwargs: Keyword arguments forwarded to the backend reader's
        constructor when `backend` is a `BackendType`.

    Raises:
      TypeError: If `backend` is neither a `BackendReader` nor a
        `BackendType`.
    """
    logging.info('backend: %r', backend)
    logging.info('backend_args: %r', backend_args)
    logging.info('backend_kwargs: %r', backend_kwargs)
    # Set backend.
    if isinstance(backend, backend_reader.BackendReader):
      self._backend = backend
    elif isinstance(backend, backend_type.BackendType):
      self._backend = {
          backend_type.BackendType.RIEGELI:
              riegeli_backend_reader.RiegeliBackendReader,
          backend_type.BackendType.IN_MEMORY:
              in_memory_backend.InMemoryBackendReader,
      }[backend](*backend_args, **backend_kwargs)
    else:
      raise TypeError(f'Unsupported backend: {backend}')
    self._set_specs()

  def copy(self):
    """Returns a shallow copy of this reader backed by a copy of its backend."""
    c = copy.copy(self)

    c._backend = self._backend.copy()
    # The specs are shared with the original reader. This previously assigned
    # to a misspelled `_observation` attribute.
    c._observation_spec = self._observation_spec
    c._action_spec = self._action_spec
    c._reward_spec = self._reward_spec
    c._discount_spec = self._discount_spec

    return c

  def close(self):
    """Closes the underlying backend."""
    self._backend.close()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, tb):
    self.close()

  def __del__(self):
    # `_backend` is missing if `__init__` raised before assigning it (e.g. an
    # unsupported backend); closing would then fail with AttributeError
    # during garbage collection.
    if getattr(self, '_backend', None) is not None:
      self.close()

  def metadata(self):
    """Returns the metadata reported by the backend."""
    return self._backend.metadata()

  @property
  def steps(self) -> Sequence[step_data.StepData]:
    """The backend's sequence of logged steps."""
    return self._backend.steps

  @property
  def episodes(self) -> Sequence[Sequence[step_data.StepData]]:
    """The backend's sequence of episodes, each a sequence of steps."""
    return self._backend.episodes

  def episode_metadata(self) -> Sequence[Optional[Any]]:
    """Returns the per-episode metadata reported by the backend."""
    return self._backend.episode_metadata()

  def observation_spec(self):
    """Returns the observation spec decoded from the logged metadata."""
    return self._observation_spec

  def action_spec(self):
    """Returns the action spec decoded from the logged metadata."""
    return self._action_spec

  def reward_spec(self):
    """Returns the reward spec decoded from the logged metadata."""
    return self._reward_spec

  def discount_spec(self):
    """Returns the discount spec decoded from the logged metadata."""
    return self._discount_spec

  def _set_specs(self) -> None:
    """Extracts and decodes environment specs from the logged data."""
    # Backends may report no metadata at all; treat that as an empty dict.
    metadata = self._backend.metadata() or {}
    env_specs = spec_codec.decode_environment_specs(
        metadata.get('environment_specs', {}))
    self._observation_spec = env_specs['observation_spec']
    self._action_spec = env_specs['action_spec']
    self._reward_spec = env_specs['reward_spec']
    self._discount_spec = env_specs['discount_spec']
115 |
--------------------------------------------------------------------------------
/envlogger/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | dm_env
3 | numpy
4 | protobuf
5 | tensorflow
6 | tfds-nightly
7 |
--------------------------------------------------------------------------------
/envlogger/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Install script for setuptools."""
17 |
18 | import os
19 | import posixpath
20 | import shutil
21 |
22 | import pkg_resources
23 | import setuptools
24 | from setuptools.command import build_ext
25 | from setuptools.command import build_py
26 |
# Distribution name passed to setuptools.setup().
PROJECT_NAME = 'envlogger'

# Package version passed to setuptools.setup().
__version__ = '1.2'

# Absolute path of the directory containing this setup.py; used as the proto
# include path and as the output directory for generated bindings.
_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Proto sources, relative to _ROOT_DIR, that get Python bindings generated.
_ENVLOGGER_PROTOS = (
    'proto/storage.proto',
)
36 |
37 |
class _GenerateProtoFiles(setuptools.Command):
  """Command to generate protobuf bindings for EnvLogger protos."""

  # setuptools/distutils read the singular `description` attribute when
  # listing commands; the former `descriptions` spelling was never picked up.
  description = 'Generates Python protobuf bindings for EnvLogger protos.'
  user_options = []

  def initialize_options(self):
    """No options to initialize; required by the Command interface."""
    pass

  def finalize_options(self):
    """No options to finalize; required by the Command interface."""
    pass

  def run(self):
    """Runs protoc on every proto listed in `_ENVLOGGER_PROTOS`.

    Raises:
      RuntimeError: If protoc returns a non-zero status for any proto.
    """
    # We have to import grpc_tools here, after setuptools has installed
    # setup_requires dependencies.
    from grpc_tools import protoc

    # Directory of the .proto files bundled with grpc_tools, added to the
    # include path alongside this project's root.
    grpc_protos_include = pkg_resources.resource_filename(
        'grpc_tools', '_proto')

    for proto_path in _ENVLOGGER_PROTOS:
      proto_args = [
          'grpc_tools.protoc',
          '--proto_path={}'.format(grpc_protos_include),
          '--proto_path={}'.format(_ROOT_DIR),
          '--python_out={}'.format(_ROOT_DIR),
          '--grpc_python_out={}'.format(_ROOT_DIR),
          os.path.join(_ROOT_DIR, proto_path),
      ]
      if protoc.main(proto_args) != 0:
        raise RuntimeError('ERROR: {}'.format(proto_args))
69 |
70 |
class _BuildPy(build_py.build_py):
  """`build_py` variant that regenerates the protobuf bindings first."""

  def run(self):
    """Generates the proto bindings, then runs the regular `build_py` step."""
    self.run_command('generate_protos')
    super().run()
77 |
78 |
class BazelExtension(setuptools.Extension):
  """Extension whose binary is produced by a Bazel BUILD target."""

  def __init__(self, bazel_target):
    """Derives the extension name from a Bazel label.

    Args:
      bazel_target: Bazel label of the form `//path/to/package:target`.
    """
    self.bazel_target = bazel_target
    # Strip the leading `//` and split the label into package path and target.
    label = posixpath.relpath(bazel_target, '//')
    self.relpath, self.target_name = label.split(':')
    # Bazel labels use `/`; the extension name uses the native separator.
    native_relpath = self.relpath.replace(posixpath.sep, os.path.sep)
    super().__init__(
        os.path.join(native_relpath, self.target_name), sources=[])
89 |
90 |
class _BuildExt(build_ext.build_ext):
  """A command that runs Bazel to build a C/C++ extension."""

  def run(self):
    """Generates protos, builds with Bazel, then runs the stock build_ext."""
    self.run_command('generate_protos')
    self.bazel_build()
    build_ext.build_ext.run(self)

  def bazel_build(self):
    """Builds everything with Bazel and copies each extension into place."""
    # `exist_ok` avoids the check-then-create race of a separate exists() test.
    os.makedirs(self.build_temp, exist_ok=True)

    bazel_argv = [
        'bazel',
        'build',
        '...',
        '--symlink_prefix=' + os.path.join(self.build_temp, 'bazel-'),
        '--compilation_mode=' + ('dbg' if self.debug else 'opt'),
        '--verbose_failures',
    ]

    self.spawn(bazel_argv)

    for ext in self.extensions:
      ext_bazel_bin_path = os.path.join(
          self.build_temp, 'bazel-bin',
          ext.relpath, ext.target_name + '.so')

      self._install_extension(ext_bazel_bin_path, ext.name)

      # Copy things from /external to their own libs
      # E.g. /external/some_repo/some_lib --> /some_lib
      if ext.name.startswith('external/'):
        stripped_name = '/'.join(ext.name.split('/')[2:])
        self._install_extension(ext_bazel_bin_path, stripped_name)

  def _install_extension(self, built_path, ext_name):
    """Copies a Bazel-built shared object to where setuptools expects it."""
    ext_dest_path = self.get_ext_fullpath(ext_name)
    ext_dest_dir = os.path.dirname(ext_dest_path)
    if ext_dest_dir:
      os.makedirs(ext_dest_dir, exist_ok=True)
    shutil.copyfile(built_path, ext_dest_path)
137 |
138 |
# Package definition. The three extension modules are built by Bazel (see
# BazelExtension / _BuildExt), and the custom commands below hook proto
# generation into both build_py and build_ext.
setuptools.setup(
    name=PROJECT_NAME,
    version=__version__,
    description='EnvLogger: A tool for recording trajectories.',
    author='DeepMind',
    license='Apache 2.0',
    ext_modules=[
        BazelExtension('//envlogger/backends/python:episode_info'),
        BazelExtension('//envlogger/backends/python:riegeli_dataset_reader'),
        BazelExtension('//envlogger/backends/python:riegeli_dataset_writer'),
    ],
    cmdclass={
        'build_ext': _BuildExt,
        'build_py': _BuildPy,
        'generate_protos': _GenerateProtoFiles,
    },
    packages=setuptools.find_packages(),
    setup_requires=[
        # Some software packages have problems with older versions already
        # installed by pip. In particular DeepMind Acme uses grpcio-tools 1.45.0
        # (as of 2022-04-20) so we require at least that version here.
        'grpcio-tools>=1.45.0',
    ],
    install_requires=[
        'absl-py',
        'dm_env',
        'numpy',
        'protobuf>=3.14',
        'setuptools!=50.0.0',  # https://github.com/pypa/setuptools/issues/2350
    ],
    # Optional dependencies, installed with the `tfds` extra.
    extras_require={
        'tfds': [
            'tensorflow',
            'tfds-nightly',
        ],
    })
175 |
--------------------------------------------------------------------------------
/envlogger/step_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data type that's written to and returned from storage.
17 | """
18 |
19 | from typing import Any, NamedTuple
20 |
21 | import dm_env
22 |
23 |
class StepData(NamedTuple):
  """Payload that's written at every dm_env.Environment.step() call.

  `StepData` contains the data that's written to logs (i.e. to disk somewhere).

  Attributes:
    timestep: The dm_env.TimeStep generated by the environment.
    action: The action that led to the generated `timestep`.
    custom_data: Any client-specific data to be written alongside `timestep`
      and `action`. It must be supported by converters/codec.py. Defaults to
      None.
  """
  timestep: dm_env.TimeStep
  action: Any
  custom_data: Any = None
38 |
--------------------------------------------------------------------------------
/envlogger/testing/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Testing.
16 | load("@rules_python//python:defs.bzl", "py_library")
17 |
18 | package(default_visibility = ["//visibility:public"])
19 |
# Library target exposing the Catch environment defined in catch_env.py.
py_library(
    name = "catch_env",
    srcs = ["catch_env.py"],
)
24 |
--------------------------------------------------------------------------------
/envlogger/testing/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envlogger/testing/catch_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 DeepMind Technologies Limited..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Catch reinforcement learning environment."""
17 |
18 | import dm_env
19 | from dm_env import specs
20 | import numpy as np
21 |
22 | _ACTIONS = (-1, 0, 1) # Left, no-op, right.
23 |
24 |
class Catch(dm_env.Environment):
  """Catch, implemented as a `dm_env.Environment`.

  A ball falls straight down the column it starts in, one row per step, and
  the agent slides a paddle along the bottom row to intercept it.

  Observations are (rows, columns) arrays holding 1 in the cells occupied by
  the ball or the paddle and 0 everywhere else.

  Actions are discrete indices into `_ACTIONS`: move left, no-op, move right.

  An episode ends on the step at which the ball reaches the paddle's row.
  """

  def __init__(self, rows=10, columns=5, seed=1):
    """Builds a Catch board of the given size.

    Args:
      rows: number of rows.
      columns: number of columns.
      seed: random seed for the RNG.
    """
    self._rows = rows
    self._columns = columns
    self._rng = np.random.RandomState(seed)
    self._board = np.zeros((rows, columns), dtype=np.float32)
    # Ball and paddle positions are only defined once reset() has run.
    self._ball_x = self._ball_y = self._paddle_x = None
    self._paddle_y = rows - 1
    self._reset_next_step = True

  def reset(self):
    """Starts a new episode and returns its first `TimeStep`."""
    self._reset_next_step = False
    # Ball starts in a random column of the top row; paddle starts centred.
    self._ball_x, self._ball_y = self._rng.randint(self._columns), 0
    self._paddle_x = self._columns // 2
    first_observation = self._observation()
    return dm_env.restart(first_observation)

  def step(self, action):
    """Applies `action`, advances the ball and returns the next `TimeStep`."""
    if self._reset_next_step:
      return self.reset()

    # Slide the paddle, keeping it on the board.
    shifted = self._paddle_x + _ACTIONS[action]
    self._paddle_x = np.clip(shifted, 0, self._columns - 1)

    # The ball falls by one row.
    self._ball_y += 1

    # Episode continues until the ball is level with the paddle.
    if self._ball_y != self._paddle_y:
      return dm_env.transition(reward=0., observation=self._observation())

    if self._paddle_x == self._ball_x:
      reward = 1.
    else:
      reward = -1.
    self._reset_next_step = True
    return dm_env.termination(reward=reward, observation=self._observation())

  def observation_spec(self):
    """Returns the spec describing the board observation."""
    board = self._board
    return specs.BoundedArray(
        shape=board.shape, dtype=board.dtype, name="board",
        minimum=0, maximum=1)

  def action_spec(self):
    """Returns the spec describing the discrete action."""
    num_actions = len(_ACTIONS)
    return specs.DiscreteArray(
        dtype=int, num_values=num_actions, name="action")

  def _observation(self):
    """Redraws the board and returns a copy of it."""
    board = self._board
    board.fill(0.)
    board[self._ball_y, self._ball_x] = 1.
    board[self._paddle_y, self._paddle_x] = 1.
    return board.copy()
101 |
--------------------------------------------------------------------------------
/patches/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | licenses(["notice"]) # Apache 2.0
16 |
--------------------------------------------------------------------------------
/patches/crc32.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
# Source and header lists for the crc32c library, grouped by implementation.
# NOTE(review): crc32c_arm64_HDRS and crc32c_sse42_HDRS are not referenced by
# any rule visible in this file; the same headers are also listed in
# crc32c_HDRS, which is what the cc_library uses.
crc32c_arm64_HDRS = [
    "src/crc32c_arm64.h",
]

crc32c_arm64_SRCS = [
    "src/crc32c_arm64.cc",
]

crc32c_sse42_HDRS = [
    "src/crc32c_sse42.h",
]

crc32c_sse42_SRCS = [
    "src/crc32c_sse42.cc",
]

# All headers, including the public include/crc32c/crc32c.h.
crc32c_HDRS = [
    "src/crc32c_arm64.h",
    "src/crc32c_arm64_linux_check.h",
    "src/crc32c_internal.h",
    "src/crc32c_prefetch.h",
    "src/crc32c_read_le.h",
    "src/crc32c_round_up.h",
    "src/crc32c_sse42.h",
    "src/crc32c_sse42_check.h",
    "include/crc32c/crc32c.h",
]

# Portable implementation plus the dispatching entry point.
crc32c_SRCS = [
    "src/crc32c_portable.cc",
    "src/crc32c.cc",
]
51 |
# Platform selectors keyed on the --cpu value.
config_setting(
    name = "windows",
    values = {"cpu": "x64_windows"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "linux_x86_64",
    values = {"cpu": "k8"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "darwin",
    values = {"cpu": "darwin"},
    visibility = ["//visibility:public"],
)

# Compiler flags enabling SSE4.2 on the platforms above; empty elsewhere.
sse42_copts = select({
    ":windows": ["/arch:AVX"],
    ":linux_x86_64": ["-msse4.2"],
    ":darwin": ["-msse4.2"],
    "//conditions:default": [],
})

# NOTE(review): sse42_enabled is not referenced by any rule visible in this
# file; the generate_config genrule hard-codes HAVE_SSE42 to 1 instead.
sse42_enabled = select({
    ":windows": "1",
    ":linux_x86_64": "1",
    ":darwin": "1",
    "//conditions:default": "0",
})
83 |
# Produces crc32c/crc32c_config.h from the CMake template: each
# `#cmakedefine01 NAME` line becomes `#define NAME <value>` with the fixed
# values below. HAVE_SSE42 is set to 1 unconditionally; every other feature
# macro is set to 0.
genrule(
    name = "generate_config",
    srcs = ["src/crc32c_config.h.in"],
    outs = ["crc32c/crc32c_config.h"],
    cmd = """
sed -e 's/#cmakedefine01/#define/' \
-e 's/ BYTE_ORDER_BIG_ENDIAN/ BYTE_ORDER_BIG_ENDIAN 0/' \
-e 's/ HAVE_BUILTIN_PREFETCH/ HAVE_BUILTIN_PREFETCH 0/' \
-e 's/ HAVE_MM_PREFETCH/ HAVE_MM_PREFETCH 0/' \
-e 's/ HAVE_SSE42/ HAVE_SSE42 1/' \
-e 's/ HAVE_ARM64_CRC32C/ HAVE_ARM64_CRC32C 0/' \
-e 's/ HAVE_STRONG_GETAUXVAL/ HAVE_STRONG_GETAUXVAL 0/' \
-e 's/ HAVE_WEAK_GETAUXVAL/ HAVE_WEAK_GETAUXVAL 0/' \
-e 's/ CRC32C_TESTS_BUILT_WITH_GLOG/ CRC32C_TESTS_BUILT_WITH_GLOG 0/' \
< $< > $@
""",
)
101 |
# The crc32c library. The portable, SSE4.2 and ARM64 sources are all listed
# regardless of platform, and the generated config header is exported
# alongside the regular headers.
cc_library(
    name = "crc32c",
    srcs = crc32c_SRCS + crc32c_sse42_SRCS + crc32c_arm64_SRCS,
    hdrs = crc32c_HDRS + ["crc32c/crc32c_config.h"],
    deps = [],
    includes = ["include"],
    copts = sse42_copts,
)
110 |
--------------------------------------------------------------------------------
/patches/gmp.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the GMP library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
# Exposes gmpxx.h and links against libgmp and libgmpxx through linker flags;
# the libraries themselves are not built by this target (it has no srcs).
cc_library(
    name = "gmp",
    hdrs = ["gmpxx.h"],
    includes = ["."],
    linkopts = [
        "-lgmp",
        "-lgmpxx",
    ],
    visibility = ["//visibility:public"],
)
29 |
--------------------------------------------------------------------------------
/patches/net_zstd.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
# Builds zstd from the sources and internal headers under common/, compress/
# and decompress/; zstd.h is the only header exported to dependents.
cc_library(
    name = "zstdlib",
    srcs = glob([
        "common/*.c",
        "common/*.h",
        "compress/*.c",
        "compress/*.h",
        "decompress/*.c",
        "decompress/*.h",
    ]),
    hdrs = ["zstd.h"],
)
31 |
--------------------------------------------------------------------------------
/patches/proto_utils.cc.diff:
--------------------------------------------------------------------------------
1 | --- pybind11_protobuf/proto_utils.cc
2 | +++ pybind11_protobuf/proto_utils.cc
3 | @@ -13,7 +13,9 @@
4 | #include "google/protobuf/any.pb.h"
5 | #include "google/protobuf/descriptor.pb.h"
6 | #include "google/protobuf/descriptor.h"
7 | +#include "google/protobuf/io/zero_copy_stream_impl.h"
8 | #include "google/protobuf/message.h"
9 | +
10 | #include "absl/strings/str_format.h"
11 |
12 | void pybind11_proto_casters_collision() {
13 | @@ -1024,8 +1026,8 @@
14 | any_proto.value());
15 | } else {
16 | bytes serialized(nullptr, any_proto.value().size());
17 | - absl::SNPrintF(PYBIND11_BYTES_AS_STRING(serialized.ptr()),
18 | - any_proto.value().size(), any_proto.value().c_str());
19 | + any_proto.value().copy(PYBIND11_BYTES_AS_STRING(serialized.ptr()),
20 | + any_proto.value().size());
21 | getattr(py_proto, "ParseFromString")(serialized);
22 | return true;
23 | }
24 |
--------------------------------------------------------------------------------
/patches/riegeli.diff:
--------------------------------------------------------------------------------
1 | ---
2 | python/riegeli/records/BUILD | 4 +++-
3 | 1 file changed, 3 insertions(+), 1 deletion(-)
4 |
5 | diff --git a/python/riegeli/records/BUILD b/python/riegeli/records/BUILD
6 | index cde6d546..3daa9dad 100644
7 | --- a/python/riegeli/records/BUILD
8 | +++ b/python/riegeli/records/BUILD
9 | @@ -90,5 +90,7 @@ py_library(
10 | py_proto_library(
11 | name = "records_metadata_py_pb2",
12 | srcs = ["records_metadata.proto"],
13 | - deps = ["@com_google_protobuf//:protobuf_python"],
14 | + deps = [
15 | + "@com_google_protobuf//:well_known_types_py_pb2",
16 | + ],
17 | )
18 | --
19 | 2.25.1
20 |
21 |
--------------------------------------------------------------------------------
/patches/snappy.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited..
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
# The snappy library. config.h and snappy-stubs-public.h are not checked-in
# files: they are the outputs of the :config_h and :snappy_stubs_public_h
# genrules in this file.
cc_library(
    name = "snappy",
    srcs = [
        "config.h",
        "snappy.cc",
        "snappy-internal.h",
        "snappy-sinksource.cc",
        "snappy-stubs-internal.cc",
        "snappy-stubs-internal.h",
        "snappy-stubs-public.h",
    ],
    hdrs = ["snappy-sinksource.h", "snappy.h"],
    # HAVE_CONFIG_H is defined so the sources pick up the generated config.h
    # (presumably via an #ifdef in the snappy sources -- not visible here).
    copts = ["-DHAVE_CONFIG_H", "-Wno-sign-compare"],
)
33 |
# Writes config.h through a quoted heredoc, so the shell copies the text
# verbatim. Feature macros are derived at compile time from __has_builtin /
# __has_include instead of a configure step.
# Each __has_include needs a header operand; an empty `__has_include()` is a
# preprocessor error. The header names below follow the HAVE_*_H macro that
# each check defines.
genrule(
    name = "config_h",
    outs = ["config.h"],
    cmd = "\n".join([
        "cat <<'EOF' >$@",
        "#define HAVE_STDDEF_H 1",
        "#define HAVE_STDINT_H 1",
        "",
        "#ifdef __has_builtin",
        "# if !defined(HAVE_BUILTIN_EXPECT) && __has_builtin(__builtin_expect)",
        "# define HAVE_BUILTIN_EXPECT 1",
        "# endif",
        "# if !defined(HAVE_BUILTIN_CTZ) && __has_builtin(__builtin_ctzll)",
        "# define HAVE_BUILTIN_CTZ 1",
        "# endif",
        "#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)",
        "# ifndef HAVE_BUILTIN_EXPECT",
        "# define HAVE_BUILTIN_EXPECT 1",
        "# endif",
        "# ifndef HAVE_BUILTIN_CTZ",
        "# define HAVE_BUILTIN_CTZ 1",
        "# endif",
        "#endif",
        "",
        "#ifdef __has_include",
        "# if !defined(HAVE_BYTESWAP_H) && __has_include(<byteswap.h>)",
        "# define HAVE_BYTESWAP_H 1",
        "# endif",
        "# if !defined(HAVE_UNISTD_H) && __has_include(<unistd.h>)",
        "# define HAVE_UNISTD_H 1",
        "# endif",
        "# if !defined(HAVE_SYS_ENDIAN_H) && __has_include(<sys/endian.h>)",
        "# define HAVE_SYS_ENDIAN_H 1",
        "# endif",
        "# if !defined(HAVE_SYS_MMAN_H) && __has_include(<sys/mman.h>)",
        "# define HAVE_SYS_MMAN_H 1",
        "# endif",
        "# if !defined(HAVE_SYS_UIO_H) && __has_include(<sys/uio.h>)",
        "# define HAVE_SYS_UIO_H 1",
        "# endif",
        "#endif",
        "",
        "#ifndef SNAPPY_IS_BIG_ENDIAN",
        "# ifdef __s390x__",
        "# define SNAPPY_IS_BIG_ENDIAN 1",
        "# elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__",
        "# define SNAPPY_IS_BIG_ENDIAN 1",
        "# endif",
        "#endif",
        "EOF",
    ]),
)
86 |
87 | genrule(  # Produces snappy-stubs-public.h by running sed over the .in template.
88 | name = "snappy_stubs_public_h",
89 | srcs = ["snappy-stubs-public.h.in"],
90 | outs = ["snappy-stubs-public.h"],
91 | cmd = ("sed " +  # "$$" is Bazel's escape for a literal "$" inside a genrule command.
92 | "-e 's/$${\\(.*\\)_01}/\\1/g' " +  # Rewrites each ${NAME_01} placeholder to NAME.
93 | "-e 's/$${SNAPPY_MAJOR}/1/g' " +  # Substitutes 1 for ${SNAPPY_MAJOR}.
94 | "-e 's/$${SNAPPY_MINOR}/1/g' " +  # Substitutes 1 for ${SNAPPY_MINOR}.
95 | "-e 's/$${SNAPPY_PATCHLEVEL}/4/g' " +  # Substitutes 4 for ${SNAPPY_PATCHLEVEL}.
96 | "$< >$@"),  # $< is the single input file, $@ the single output file.
97 | )
98 |
--------------------------------------------------------------------------------
/patches/xtensor.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the xtensor library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
19 | filegroup(  # Gathers every file under include/ so the library target below can reference them as one label.
20 | name = "all",
21 | srcs = glob(["include/**"]),
22 | )
23 |
24 | cc_library(  # Declares no srcs: the target only exposes the headers collected by :all.
25 | name = "xtensor",
26 | hdrs = [":all"],
27 | defines = ["XTENSOR_GLIBCXX_USE_CXX11_ABI"],  # Passed to this target and to everything that depends on it.
28 | strip_include_prefix = "include/",  # Drops the leading include/ from the header include paths.
29 | visibility = ["//visibility:public"],
30 | deps = ["@xtl//:xtl"],
31 | )
32 |
--------------------------------------------------------------------------------
/patches/xtl.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build target for the xtl library.
16 |
17 | load("@rules_cc//cc:defs.bzl", "cc_library")
18 |
19 | filegroup(  # Gathers every file under include/ so the library target below can reference them as one label.
20 | name = "all",
21 | srcs = glob(["include/**"]),
22 | )
23 |
24 | cc_library(  # Declares no srcs: the target only exposes the headers collected by :all.
25 | name = "xtl",
26 | visibility = ["//visibility:public"],
27 | hdrs = [":all"],
28 | strip_include_prefix = "include/",  # Drops the leading include/ from the header include paths.
29 | )
30 |
--------------------------------------------------------------------------------
/patches/zlib.BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | package(default_visibility = ["//visibility:public"])
16 |
17 | licenses(["notice"])
18 |
19 | cc_library(  # zlib library target; visibility comes from the package default declared above.
20 | name = "zlib",
21 | hdrs = [  # Headers exposed to dependents.
22 | "zconf.h",
23 | "zlib.h",
24 | ],
25 | includes = ["."],  # Adds the package directory to the include search path.
26 | copts = ["-Wno-implicit-function-declaration"],  # Silences implicit-function-declaration diagnostics when compiling these sources.
27 | srcs = [  # Implementation files and internal headers.
28 | "adler32.c",
29 | "compress.c",
30 | "crc32.c",
31 | "crc32.h",
32 | "deflate.c",
33 | "deflate.h",
34 | "gzclose.c",
35 | "gzguts.h",
36 | "gzlib.c",
37 | "gzread.c",
38 | "gzwrite.c",
39 | "infback.c",
40 | "inffast.c",
41 | "inffast.h",
42 | "inffixed.h",
43 | "inflate.c",
44 | "inflate.h",
45 | "inftrees.c",
46 | "inftrees.h",
47 | "trees.c",
48 | "trees.h",
49 | "uncompr.c",
50 | "zutil.c",
51 | "zutil.h",
52 | ],
53 | )
54 |
--------------------------------------------------------------------------------