├── .bazelrc
├── .github
│   └── workflows
│       └── python-tests.yml
├── BUILD
├── CONTRIBUTING.md
├── LICENSE
├── MODULE.bazel
├── README.md
├── __init__.py
├── beam
│   ├── README.md
│   ├── __init__.py
│   ├── arrayrecordio.py
│   ├── demo.py
│   ├── dofns.py
│   ├── example.py
│   ├── examples
│   │   ├── example_full_demo_cli.sh
│   │   ├── example_gcs_conversion.py
│   │   ├── example_sink_conversion.py
│   │   └── requirements.txt
│   ├── options.py
│   ├── pipelines.py
│   └── testdata.py
├── cpp
│   ├── BUILD
│   ├── array_record_reader.cc
│   ├── array_record_reader.h
│   ├── array_record_reader_test.cc
│   ├── array_record_writer.cc
│   ├── array_record_writer.h
│   ├── array_record_writer_test.cc
│   ├── common.h
│   ├── layout.proto
│   ├── masked_reader.cc
│   ├── masked_reader.h
│   ├── masked_reader_test.cc
│   ├── parallel_for.h
│   ├── parallel_for_test.cc
│   ├── sequenced_chunk_writer.cc
│   ├── sequenced_chunk_writer.h
│   ├── sequenced_chunk_writer_test.cc
│   ├── test_utils.cc
│   ├── test_utils.h
│   ├── test_utils_test.cc
│   ├── thread_pool.cc
│   ├── thread_pool.h
│   ├── tri_state_ptr.h
│   └── tri_state_ptr_test.cc
├── oss
│   ├── README.md
│   ├── build.Dockerfile
│   ├── build_whl.sh
│   └── runner_common.sh
├── python
│   ├── BUILD
│   ├── __init__.py
│   ├── array_record_data_source.py
│   ├── array_record_data_source_test.py
│   ├── array_record_module.cc
│   ├── array_record_module_test.py
│   └── testdata
│       ├── BUILD
│       ├── digits.array_record-00000-of-00002
│       └── digits.array_record-00001-of-00002
├── requirements.in
├── requirements_lock.txt
└── setup.py
/.bazelrc:
--------------------------------------------------------------------------------
1 | build -c opt
2 | build --cxxopt=-std=c++17
3 | build --host_cxxopt=-std=c++17
4 | build --experimental_repo_remote_exec
5 |
6 | # TODO(fchern): Use non-hardcoded paths.
7 | build --action_env=PYTHON_BIN_PATH="/usr/bin/python3"
8 | build --action_env=PYTHON_LIB_PATH="/usr/lib/python3"
9 | build --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
10 | build --python_path="/usr/bin/python3"
11 |
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
1 | name: Build and test
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | matrix:
12 | python-version: ['3.9', '3.10', '3.11']
13 | env:
14 | DOCKER_BUILDKIT: 1
15 | TMP_FOLDER: /tmp/array_record
16 | steps:
17 | - uses: actions/checkout@v2
18 | - name: Build Docker image
19 | run: |
20 | docker build --progress=plain --no-cache \
21 | --build-arg PYTHON_VERSION=${{ matrix.python-version }} \
22 | -t array_record:latest - < oss/build.Dockerfile
23 | - name: Build wheels and test
24 | run: |
25 | docker run --rm -a stdin -a stdout -a stderr \
26 | --env PYTHON_VERSION=${{ matrix.python-version }} \
27 | --volume ${GITHUB_WORKSPACE}:${TMP_FOLDER} --name array_record array_record:latest \
28 | bash oss/build_whl.sh
29 | - name: Install in a clean Docker container and test the import in Python
30 | run: |
31 | docker run --rm -a stdin -a stdout -a stderr \
32 | --env PYTHON_VERSION=${{ matrix.python-version }} \
33 | --volume ${GITHUB_WORKSPACE}:/root \
34 | python:${{ matrix.python-version }} bash -c "
35 | ARRAY_RECORD_VERSION=\$(python /root/setup.py --version 2>/dev/null)
36 | SHORT_PYTHON_VERSION=\${PYTHON_VERSION//./}
37 | ARRAY_RECORD_WHEEL=\"/root/all_dist/array_record-\${ARRAY_RECORD_VERSION}-py\${SHORT_PYTHON_VERSION}-none-any.whl\"
38 | python -m pip install \${ARRAY_RECORD_WHEEL} &&
39 | python -c 'import array_record' &&
40 | python -c 'from array_record.python import array_record_data_source'
41 | "
42 |
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
1 | # ArrayRecord is a new file format for IO intensive applications.
2 | # It supports efficient random access and various compression algorithms.
3 |
4 | load("@rules_python//python:pip.bzl", "compile_pip_requirements")
5 |
6 |
7 | package(default_visibility = ["//visibility:public"])
8 |
9 | licenses(["notice"])
10 |
11 | exports_files(["LICENSE"])
12 |
13 | py_library(
14 | name = "setup",
15 | srcs = ["setup.py"],
16 | )
17 |
18 | compile_pip_requirements(
19 | name = "requirements",
20 | requirements_in = "requirements.in",
21 | requirements_txt = "requirements_lock.txt",
22 | )
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MODULE.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The ArrayRecord Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # TODO(fchern): automate version string alignment with setup.py
16 | VERSION = "0.6.0"
17 |
18 | module(
19 | name = "array_record",
20 | version = VERSION,
21 | repo_name = "com_google_array_record",
22 | )
23 |
24 | bazel_dep(name = "rules_proto", version = "7.0.2")
25 | bazel_dep(name = "rules_python", version = "0.40.0")
26 | bazel_dep(name = "platforms", version = "0.0.10")
27 | bazel_dep(name = "protobuf", version = "29.0")
28 | bazel_dep(name = "googletest", version = "1.15.2")
29 | bazel_dep(name = "abseil-cpp", version = "20240722.0")
30 | bazel_dep(name = "abseil-py", version = "2.1.0")
31 | bazel_dep(name = "eigen", version = "3.4.0.bcr.2")
32 | bazel_dep(name = "riegeli", version = "0.0.0-20241218-3385e3c")
33 | bazel_dep(name = "pybind11_bazel", version = "2.12.0")
34 |
35 | PYTHON_VERSION = "3.10"
36 |
37 | python = use_extension("@rules_python//python/extensions:python.bzl", "python")
38 | python.toolchain(
39 | ignore_root_user_error = True, # Required for our containerized CI environments.
40 | python_version = PYTHON_VERSION,
41 | )
42 |
43 | pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
44 |
45 | # requirements_lock.txt is generated by
46 | # bazel run //:requirements.update
47 | pip.parse(
48 | hub_name = "pypi",
49 | python_version = PYTHON_VERSION,
50 | requirements_lock = "//:requirements_lock.txt",
51 | )
52 | use_repo(pip, "pypi")
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ArrayRecord
2 |
3 | ArrayRecord is a new file format derived from
4 | [Riegeli](https://github.com/google/riegeli), achieving a new
5 | frontier of IO efficiency. We designed ArrayRecord to support parallel read,
6 | write, and random access by record index.
7 | ArrayRecord builds on top of Riegeli and supports the same compression
8 | algorithms.
9 |
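A minimal usage sketch from Python (a sketch under stated assumptions, not official documentation: it assumes the `array-record` PyPI package is installed, the path and record contents are illustrative, and the data-source indexing API is taken from `python/array_record_data_source.py` in this repo):

```
from array_record.python.array_record_module import ArrayRecordWriter
from array_record.python.array_record_data_source import ArrayRecordDataSource

# Write a few records. 'group_size:1' lays out one record per chunk,
# which favors random access.
writer = ArrayRecordWriter('/tmp/demo.array_record', 'group_size:1')
for i in range(3):
  writer.write(f'record-{i}'.encode('utf-8'))
writer.close()  # The file is not readable until the writer is closed.

# Random access by record index.
ds = ArrayRecordDataSource(['/tmp/demo.array_record'])
print(len(ds))  # 3
print(ds[1])    # b'record-1'
```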
10 |
11 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/__init__.py
--------------------------------------------------------------------------------
/beam/README.md:
--------------------------------------------------------------------------------
1 | ## Apache Beam Integration for ArrayRecord
2 |
3 | ### Quickstart
4 |
5 | #### Convert TFRecord in a GCS bucket to ArrayRecord
6 | ```
7 | pip install apache-beam[gcp]==2.53.0
8 | pip install array-record[beam]
9 | # check that apache-beam is still at 2.53.0
10 | pip show apache-beam
11 | git clone https://github.com/google/array_record.git
12 | cd array_record/beam/examples
13 | # Fill in the required fields in example_gcs_conversion.py
14 | # If using Dataflow, set pipeline_options as instructed in example_gcs_conversion.py
15 | python example_gcs_conversion.py
16 | ```
17 | If Dataflow is used, you can monitor the run from the Dataflow job monitoring UI (https://cloud.google.com/dataflow/docs/guides/monitoring-overview).
18 |
19 | ### Summary
20 |
21 | This submodule provides some Apache Beam components and lightweight pipelines for converting different file formats (TFRecord at present) into ArrayRecords. The intention is to provide a variety of fairly seamless tools for migrating existing TFRecord datasets, allowing a few different choices regarding sharding and write location.
22 |
23 | There are two core components of this module:
24 |
25 | 1. A Beam `PTransform` with a `FileBasedSink` for writing ArrayRecords. It's modeled after similar components like `TFRecordIO` and `FileIO`. Worth noting in this implementation is that `array_record`'s `ArrayRecordWriter` object requires a file-path-like string to initialize, and the `.close()` method is required to make the file usable. This characteristic forces the overriding of Beam's default `.open()` functionality, which is where its scheme and file handling functionality is housed. In short, it means this sink is **only usable for ArrayRecord writes to disk or disk-like paths, e.g. FUSE, NFS mounts, etc.** All writes using scheme prefixes (e.g. `gs://`) will fail.
26 |
27 | 2. A Beam `DoFn` that accepts a single tuple consisting of a filename key and an entire set of serialized records. The function writes the serialized content to an ArrayRecord file in an on-disk path, uploads it to a specified GCS bucket, and removes the temporary file. This function has no inherent file awareness, making its primary goal the writing of a single file per PCollection. As such, it requires the file content division logic to be provided to the function elsewhere in the Beam pipeline.
28 |
29 | In addition to these components, there are a number of simple pipelines included in this module that provide basic likely implementations of the above components. A few of those pipelines are as follows:
30 |
31 | 1. **Conversion from a set number of TFRecord files in either GCS or on-disk to a flexible number of ArrayRecords on disk:** Leverages the PTransform/Sink, and due to Beam's file handling capabilities allows for a `num_shards` argument that supports redistribution of the bounded dataset across an arbitrary number of files. However, due to overriding the `open()` method, writes to GCS don't work.
32 |
33 | 2. **Conversion from a set number of TFRecord files in either GCS or on-disk to a matching number of ArrayRecords on GCS:** Leverages the `ReadAllFromTFRecord` and `GroupByKey` Beam functions to organize a set of filename:content pairs, which are then passed to the ArrayRecord `DoFn`. The end result is that TFRecords are converted to ArrayRecords one-to-one.
34 |
35 | 3. **Conversion from a set number of TFRecord files in either GCS or on-disk to a matching number of ArrayRecords on disk:** Identical to pipeline 1, it just reads the number of shards first and sets the number of ArrayRecord shards to match.
36 |
37 | In addition to all of that, there are a handful of dummy data generation functions used for testing and validation.
38 |
39 | ### Usage
40 |
41 | **Basics and 'Getting Started'**
42 |
43 | Please note that in an attempt to keep the array_record library lightweight, Apache Beam (and some of the underlying data generation dependencies like TensorFlow) is not installed by default when you run `pip install array-record`. To get the extra packages automatically, run `pip install array-record[beam]`.
44 |
45 | Once installed, all of the Beam components are available to import from `array_record.beam`.
46 |
47 | **Importing the PTransform or the DoFn**
48 |
49 | If you're familiar with Apache Beam and want to build a custom pipeline around its core constructs, you can import the native Beam objects and implement them as you see fit.
50 |
51 | To import the PTransform with the disk-based sink, use `from array_record.beam.arrayrecordio import WriteToArrayRecord`. You may then use it as a standard step in a Beam pipeline. It accepts a variety of different inputs including `file_path_prefix`, `file_name_suffix`, `coder`, and `num_shards`. For more detail, as well as options for extensibility, please refer to [Apache Beam's Documentation for FileBasedSink](https://beam.apache.org/releases/pydoc/current/apache_beam.io.filebasedsink.html).
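
A minimal sketch of the PTransform in use (the path and records are illustrative; remember that GCS paths will fail with this sink):

```
import apache_beam as beam
from array_record.beam.arrayrecordio import WriteToArrayRecord

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([b'first', b'second', b'third'])
      | WriteToArrayRecord(
          '/tmp/records/demo',  # disk or disk-like path prefix
          file_name_suffix='.arrayrecord',
          num_shards=2,
      )
  )
```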
52 |
53 |
54 | To import the custom DoFn, use `from array_record.beam.dofns import ConvertToArrayRecordGCS`. You may then use it as a parameter for a Beam `ParDo`. It takes a handful of side inputs as described below:
55 |
56 | - **path:** REQUIRED (and positional). The intended path prefix for the GCS bucket in "gs://..." format
57 | - **overwrite_extension:** FALSE by default. Boolean making the DoFn attempt to overwrite any file extension after "."
58 | - **file_path_suffix:** ".arrayrecord" by default. Intended suffix for overwrite or append
59 |
60 | Note that by default, the DoFn will APPEND ".arrayrecord" to an existing filename/extension. Setting `file_path_suffix` to `""` will leave the file names as-is, and thus expects you to pass a different `path` than the source.
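
A sketch of the DoFn in a pipeline (the bucket names are placeholders; the element shape mirrors `demo.py`, i.e. `(filename, records)` tuples produced by `ReadAllFromTFRecord(with_filename=True)` followed by `GroupByKey`):

```
import apache_beam as beam
from array_record.beam.dofns import ConvertToArrayRecordGCS

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['gs://<bucket>/records/*.tfrecord'])
      | beam.io.ReadAllFromTFRecord(with_filename=True)
      | beam.GroupByKey()
      | beam.ParDo(
          ConvertToArrayRecordGCS(),
          'gs://<bucket>/converted/',  # required positional `path`
          file_path_suffix='.arrayrecord',
          overwrite_extension=False,
      )
  )
```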
61 |
62 | You can see usage details for each of these implementations in `pipelines.py`.
63 |
64 | **Using the Helper Functions**
65 |
66 | Several helper functions have been packaged to make the functionality more accessible to those with less comfort building Apache Beam pipelines. All of these pipelines take `input` and `output` arguments, which are intended as the respective source and destination paths of the TFRecord files and the ArrayRecord files. Wildcards are accepted in these paths. By default, these parameters can either be passed as CLI arguments when executing a pipeline as a module (`python -m <module> --input <input_path> --output <output_path>`), or as an override to the `args` argument if executing programmatically. Additionally, extra arguments can be passed via CLI or programmatically in the `pipeline_options` argument if you want to control the behavior of Beam. The likely reason for this would be altering the Runner to Google Cloud Dataflow, which these examples support (with caveats; see the section below on Dataflow).
67 |
68 | There are slight variations in execution when running these either from an interpreter or the CLI, so familiarize yourself with the files in the `examples/` directory along with `demo.py`, which show the different invocation methods. The available functions can all be imported via `from array_record.beam.pipelines import *` and are as follows:
69 |
70 | - **convert_tf_to_arrayrecord_disk:** Converts TFRecords at `input` path to ArrayRecords at `output` path for disk-based writes only. Accepts an extra `num_shards` argument for resharding ArrayRecords across an arbitrary number of files.
71 | - **convert_tf_to_arrayrecord_disk_match_shards:** Same as above, except it reads the number of source files and matches them to the destination. There is no `num_shards` argument.
72 | - **convert_tf_to_arrayrecord_gcs:** Converts TFRecords at `input` path to ArrayRecords at `output` path, where the `output` path **must** be a GCS bucket in "gs://" format. This function accepts the same `overwrite_extension` and `file_path_suffix` arguments as the DoFn itself, allowing for customization of file naming.
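
For example, a minimal programmatic call to the first helper might look like this (paths are illustrative; passing `args` overrides the CLI arguments):

```
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk

convert_tf_to_arrayrecord_disk(
    num_shards=2,
    args={'input': '/tmp/records/*.tfrecord',
          'output': '/tmp/arrayrecords/movies'},
).run()
```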
73 |
74 | ### Examples and Demos
75 |
76 | See the examples in the `examples/` directory for different invocation techniques. One of the examples invokes `array_record.beam.demo` as a module, which is a simple pipeline that generates some TFRecords and then converts them to ArrayRecord in GCS. You can see the implementation in `demo.py`, which should serve as a guide for implementing your own CLI-triggered pipelines.
77 |
78 | You'll also note commented sections in each example, which are the configuration parameters for running the pipelines on Google Cloud Dataflow. There is also a `requirements.txt` in the `examples/` directory, which is currently required for running these pipelines on Dataflow as-is. See below for more detail.
79 |
80 | ### Dataflow Usage
81 |
82 | These pipelines have all been tested and are compatible with Google Cloud Dataflow. Uncomment the sections in the example files and set your own bucket/project information to see it in action.
83 |
84 | Note, however, the `requirements.txt` file. This is necessary because the `array-record` PyPI installation does not install the Apache Beam or TensorFlow components by default, to keep the library lightweight. A `requirements.txt` passed as an argument to the Dataflow job is required to ensure everything is installed correctly on the runner.
85 |
86 |
87 | Allow to simmer uncovered for 5 minutes. Plate, serve, and enjoy.
88 |
--------------------------------------------------------------------------------
/beam/__init__.py:
--------------------------------------------------------------------------------
1 | """Apache Beam module for array_record.
2 |
3 | This module provides both core components and
4 | helper functions to enable users to convert different file formats to AR.
5 |
6 | To keep dependencies light, we'll import Beam on module usage so any errors
7 | occur early.
8 | """
9 |
10 | import apache_beam as beam
11 |
12 | # I'd really like a PEP8 compatible conditional import here with a more
13 | # explicit error message. Example below:
14 |
15 | # try:
16 | # import apache_beam as beam
17 | # except Exception as e:
18 | # raise ImportError(
19 | # ('Beam functionality requires extra dependencies. '
20 | # 'Install apache-beam or run "pip install array_record[beam]".')) from e
21 |
--------------------------------------------------------------------------------
/beam/arrayrecordio.py:
--------------------------------------------------------------------------------
1 | """An IO module for ArrayRecord.
2 |
3 | CURRENTLY ONLY SINK IS IMPLEMENTED, AND IT DOESN'T WORK WITH NON-DISK WRITES
4 | """
5 |
6 | from apache_beam import io
7 | from apache_beam import transforms
8 | from apache_beam.coders import coders
9 | from apache_beam.io import filebasedsink
10 | from apache_beam.io import filesystem
11 | from array_record.python import array_record_module
12 |
13 |
14 | class _ArrayRecordSink(filebasedsink.FileBasedSink):
15 | """Sink Class for use in Arrayrecord PTransform."""
16 |
17 | def __init__(
18 | self,
19 | file_path_prefix,
20 | file_name_suffix=None,
21 | num_shards=0,
22 | shard_name_template=None,
23 | coder=coders.ToBytesCoder(),
24 | compression_type=filesystem.CompressionTypes.AUTO):
25 |
26 | super().__init__(
27 | file_path_prefix,
28 | file_name_suffix=file_name_suffix,
29 | num_shards=num_shards,
30 | shard_name_template=shard_name_template,
31 | coder=coder,
32 | mime_type='application/octet-stream',
33 | compression_type=compression_type)
34 |
35 | def open(self, temp_path):
36 | array_writer = array_record_module.ArrayRecordWriter(
37 | temp_path, 'group_size:1'
38 | )
39 | return array_writer
40 |
41 | def close(self, file_handle):
42 | file_handle.close()
43 |
44 | def write_encoded_record(self, file_handle, value):
45 | file_handle.write(value)
46 |
47 |
48 | class WriteToArrayRecord(transforms.PTransform):
49 | """PTransform for a disk-based write to ArrayRecord."""
50 |
51 | def __init__(
52 | self,
53 | file_path_prefix,
54 | file_name_suffix='',
55 | num_shards=0,
56 | shard_name_template=None,
57 | coder=coders.ToBytesCoder(),
58 | compression_type=filesystem.CompressionTypes.AUTO):
59 |
60 | self._sink = _ArrayRecordSink(
61 | file_path_prefix,
62 | file_name_suffix,
63 | num_shards,
64 | shard_name_template,
65 | coder,
66 | compression_type)
67 |
68 | def expand(self, pcoll):
69 | return pcoll | io.iobase.Write(self._sink)
70 |
--------------------------------------------------------------------------------
/beam/demo.py:
--------------------------------------------------------------------------------
1 | """Demo Pipeline.
2 |
3 | This file creates a TFRecord dataset and converts it to ArrayRecord on GCS
4 | """
5 |
6 | import apache_beam as beam
7 | from apache_beam.coders import coders
8 | from . import dofns
9 | from . import example
10 | from . import options
11 |
12 |
13 | ## Grab CLI arguments.
14 | ## Override by passing args/pipeline_options to the function manually.
15 | args, pipeline_options = options.get_arguments()
16 |
17 |
18 | def main():
19 | p1 = beam.Pipeline(options=pipeline_options)
20 | initial = (p1
21 | | 'Create a set of TFExamples' >> beam.Create(
22 | example.generate_movie_examples()
23 | )
24 | | 'Write TFRecords' >> beam.io.WriteToTFRecord(
25 | args['input'],
26 | coder=coders.ToBytesCoder(),
27 | num_shards=4,
28 | file_name_suffix='.tfrecord'
29 | )
30 | | 'Read shards from GCS' >> beam.io.ReadAllFromTFRecord(
31 | with_filename=True)
32 | | 'Group with Filename' >> beam.GroupByKey()
33 | | 'Write to ArrayRecord in GCS' >> beam.ParDo(
34 | dofns.ConvertToArrayRecordGCS(),
35 | args['output'],
36 | overwrite_extension=True))
37 |
38 | return p1, initial
39 |
40 |
41 | if __name__ == '__main__':
42 |   demo_pipeline, _ = main()  # main() returns (pipeline, final PCollection)
43 |   demo_pipeline.run()
44 |
--------------------------------------------------------------------------------
/beam/dofns.py:
--------------------------------------------------------------------------------
1 | """DoFn's for parallel processing."""
2 |
3 | import os
4 | import urllib
5 | import apache_beam as beam
6 | from array_record.python.array_record_module import ArrayRecordWriter
7 | from google.cloud import storage
8 |
9 |
10 | class ConvertToArrayRecordGCS(beam.DoFn):
11 | """Write a tuple consisting of a filename and records to GCS ArrayRecords."""
12 |
13 | _WRITE_DIR = '/tmp/'
14 |
15 | def process(
16 | self,
17 | element,
18 | path,
19 | write_dir=_WRITE_DIR,
20 | file_path_suffix='.arrayrecord',
21 | overwrite_extension=False,
22 | ):
23 |
24 | ## Upload to GCS
25 |     def upload_to_gcs(bucket_name, filename, prefix='', source_dir=write_dir):
26 | source_filename = os.path.join(source_dir, filename)
27 | blob_name = os.path.join(prefix, filename)
28 | storage_client = storage.Client()
29 | bucket = storage_client.get_bucket(bucket_name)
30 | blob = bucket.blob(blob_name)
31 | blob.upload_from_filename(source_filename)
32 |
33 | ## Simple logic for stripping a file extension and replacing it
34 | def fix_filename(filename):
35 | base_name = os.path.splitext(filename)[0]
36 | new_filename = base_name + file_path_suffix
37 | return new_filename
38 |
39 | parsed_gcs_path = urllib.parse.urlparse(path)
40 | bucket_name = parsed_gcs_path.hostname
41 | gcs_prefix = parsed_gcs_path.path.lstrip('/')
42 |
43 | if overwrite_extension:
44 | filename = fix_filename(os.path.basename(element[0]))
45 | else:
46 | filename = '{}{}'.format(os.path.basename(element[0]), file_path_suffix)
47 |
48 | write_path = os.path.join(write_dir, filename)
49 | writer = ArrayRecordWriter(write_path, 'group_size:1')
50 |
51 | for item in element[1]:
52 | writer.write(item)
53 |
54 | writer.close()
55 |
56 | upload_to_gcs(bucket_name, filename, prefix=gcs_prefix)
57 | os.remove(os.path.join(write_dir, filename))
58 |
--------------------------------------------------------------------------------
/beam/example.py:
--------------------------------------------------------------------------------
1 | """Helper file for generating TF/ArrayRecords and writing them to disk."""
2 |
3 | import os
4 | from array_record.python.array_record_module import ArrayRecordWriter
5 | import tensorflow as tf
6 | from . import testdata
7 |
8 |
9 | def generate_movie_examples():
10 | """Create a list of TF examples from the dummy data above and return it.
11 |
12 | Returns:
13 | TFExample object
14 | """
15 |
16 | examples = []
17 | for example in testdata.data:
18 | examples.append(
19 | tf.train.Example(
20 | features=tf.train.Features(
21 | feature={
22 | 'Age': tf.train.Feature(
23 | int64_list=tf.train.Int64List(value=[example['Age']])),
24 | 'Movie': tf.train.Feature(
25 | bytes_list=tf.train.BytesList(
26 | value=[
27 | m.encode('utf-8') for m in example['Movie']])),
28 | 'Movie Ratings': tf.train.Feature(
29 | float_list=tf.train.FloatList(
30 | value=example['Movie Ratings'])),
31 | 'Suggestion': tf.train.Feature(
32 | bytes_list=tf.train.BytesList(
33 | value=[example['Suggestion'].encode('utf-8')])),
34 | 'Suggestion Purchased': tf.train.Feature(
35 | float_list=tf.train.FloatList(
36 | value=[example['Suggestion Purchased']])),
37 | 'Purchase Price': tf.train.Feature(
38 | float_list=tf.train.FloatList(
39 | value=[example['Purchase Price']]))
40 | }
41 | )
42 | )
43 | )
44 |
45 |   return examples
46 |
47 |
48 | def generate_serialized_movie_examples():
49 | """Return a serialized version of the above data for byte insertion."""
50 |
51 | return [example.SerializeToString() for example in generate_movie_examples()]
52 |
53 |
54 | def write_example_to_tfrecord(example, file_path):
55 | """Write example(s) to a single TFrecord file."""
56 |
57 | with tf.io.TFRecordWriter(file_path) as writer:
58 | writer.write(example.SerializeToString())
59 |
60 |
61 | def write_example_to_arrayrecord(example, file_path):
62 |   """Write example(s) to a single ArrayRecord file."""
63 | writer = ArrayRecordWriter(file_path, 'group_size:1')
64 | writer.write(example.SerializeToString())
65 | writer.close()
66 |
67 |
68 | def kitty_tfrecord(prefix=''):
69 | """Create a TFRecord from a cat pic on the Internet.
70 |
71 | This is mainly for testing; probably don't use it.
72 |
73 | Args:
74 | prefix: A file directory in string format.
75 | """
76 |
77 | cat_in_snow = tf.keras.utils.get_file(
78 | '320px-Felis_catus-cat_on_snow.jpg',
79 | 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
80 |
81 | image_labels = {
82 | cat_in_snow: 0
83 | }
84 |
85 | image_string = open(cat_in_snow, 'rb').read()
86 | label = image_labels[cat_in_snow]
87 | image_shape = tf.io.decode_jpeg(image_string).shape
88 |
89 | feature = {
90 | 'height': tf.train.Feature(int64_list=tf.train.Int64List(
91 | value=[image_shape[0]])),
92 | 'width': tf.train.Feature(int64_list=tf.train.Int64List(
93 | value=[image_shape[1]])),
94 | 'depth': tf.train.Feature(int64_list=tf.train.Int64List(
95 | value=[image_shape[2]])),
96 | 'label': tf.train.Feature(int64_list=tf.train.Int64List(
97 | value=[label])),
98 | 'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(
99 | value=[image_string]))
100 | }
101 |
102 | example = tf.train.Example(features=tf.train.Features(feature=feature))
103 |
104 | record_file = os.path.join(prefix, 'kittymeow.tfrecord')
105 | with tf.io.TFRecordWriter(record_file) as writer:
106 | writer.write(example.SerializeToString())
107 |
--------------------------------------------------------------------------------
/beam/examples/example_full_demo_cli.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Execute this script to run a full demo that creates TFRecords and converts them
4 |
5 |
6 | # Set bucket info below. Uncomment lower lines and set values to use Dataflow.
7 | python -m array_record.beam.demo \
8 | --input="gs:///records/movies" \
9 | --output="gs:///records/" \
10 | # --region="" \
11 | # --runner="DataflowRunner" \
12 | # --project="" \
13 | # --requirements_file="requirements.txt"
14 |
--------------------------------------------------------------------------------
/beam/examples/example_gcs_conversion.py:
--------------------------------------------------------------------------------
1 | """Execute this to convert an existing set of TFRecords to ArrayRecords."""
2 |
3 |
4 | from apache_beam.options import pipeline_options
5 | from array_record.beam.pipelines import convert_tf_to_arrayrecord_gcs
6 |
7 | ## Set input and output patterns as specified
8 | input_pattern = 'gs://<bucket>/records/*.tfrecord'
9 | output_path = 'gs://<bucket>/records/'
10 |
11 | args = {'input': input_pattern, 'output': output_path}
12 |
13 | ## If running on Dataflow, set pipeline options and uncomment them in main()
14 | ## If pipeline_options is not set, a local runner is used
15 | pipeline_options = pipeline_options.PipelineOptions(
16 | runner='DataflowRunner',
17 | project='',
18 | region='',
19 | requirements_file='requirements.txt'
20 | )
21 |
22 |
23 | def main():
24 | convert_tf_to_arrayrecord_gcs(
25 | args=args,
26 | # pipeline_options=pipeline_options,
27 | ).run()
28 |
29 | if __name__ == '__main__':
30 | main()
31 |
--------------------------------------------------------------------------------
/beam/examples/example_sink_conversion.py:
--------------------------------------------------------------------------------
1 | """Execute this to convert TFRecords to ArrayRecords using the disk Sink."""
2 |
3 |
4 | from apache_beam.options import pipeline_options
5 | from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
6 |
7 | ## Set input and output patterns as specified
8 | input_pattern = 'gs://<bucket>/records/*.tfrecord'
9 | output_path = 'records/movies'
10 |
11 | args = {'input': input_pattern, 'output': output_path}
12 |
13 | ## If running on Dataflow, set pipeline options and uncomment them in main()
14 | ## If pipeline_options is not set, a local runner is used
15 | pipeline_options = pipeline_options.PipelineOptions(
16 | runner='DataflowRunner',
17 | project='',
18 | region='',
19 | requirements_file='requirements.txt'
20 | )
21 |
22 |
23 | def main():
24 | convert_tf_to_arrayrecord_disk_match_shards(
25 | args=args,
26 | # pipeline_options=pipeline_options,
27 | ).run()
28 |
29 | if __name__ == '__main__':
30 | main()
31 |
--------------------------------------------------------------------------------
/beam/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | array-record[beam]
2 | google-cloud-storage==2.11.0
3 | tensorflow==2.14.0
--------------------------------------------------------------------------------
/beam/options.py:
--------------------------------------------------------------------------------
1 | """Handler for Pipeline and Beam options that allows for cleaner importing."""
2 |
3 |
4 | import argparse
5 | from apache_beam.options import pipeline_options
6 |
7 |
8 | def get_arguments():
9 | """Simple external wrapper for argparse that allows for manual construction.
10 |
11 | Returns:
12 | 1. A dictionary of known args for use in pipelines
13 | 2. The remainder of the arguments in PipelineOptions format
14 |
15 | """
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | '--input',
20 | help='The file pattern for the input TFRecords.',)
21 | parser.add_argument(
22 | '--output',
23 | help='The path prefix for output ArrayRecords.')
24 |
25 | args, beam_args = parser.parse_known_args()
26 |   return args.__dict__, pipeline_options.PipelineOptions(beam_args)
27 |
--------------------------------------------------------------------------------
/beam/pipelines.py:
--------------------------------------------------------------------------------
1 | """Various opinionated Beam pipelines for testing different functionality."""
2 |
3 | import apache_beam as beam
4 | from apache_beam.coders import coders
5 | from . import arrayrecordio
6 | from . import dofns
7 | from . import example
8 | from . import options
9 |
10 |
11 | ## Grab CLI arguments.
12 | ## Override by passing args/pipeline_options to the function manually.
13 | def_args, def_pipeline_options = options.get_arguments()
14 |
15 |
16 | def example_to_tfrecord(
17 | num_shards=1,
18 | args=def_args,
19 | pipeline_options=def_pipeline_options):
20 | """Beam pipeline for creating example TFRecord data.
21 |
22 | Args:
23 | num_shards: Number of files
24 | args: Custom arguments
25 |     pipeline_options: Beam PipelineOptions object
26 |
27 | Returns:
28 | Beam Pipeline object
29 | """
30 |
31 | p1 = beam.Pipeline(options=pipeline_options)
32 | _ = (
33 | p1
34 | | 'Create' >> beam.Create(example.generate_movie_examples())
35 | | 'Write'
36 | >> beam.io.WriteToTFRecord(
37 | args['output'],
38 | coder=coders.ToBytesCoder(),
39 | num_shards=num_shards,
40 | file_name_suffix='.tfrecord',
41 | )
42 | )
43 | return p1
44 |
45 |
46 | def example_to_arrayrecord(
47 | num_shards=1, args=def_args, pipeline_options=def_pipeline_options
48 | ):
49 | """Beam pipeline for creating example ArrayRecord data.
50 |
51 | Args:
52 | num_shards: Number of files
53 | args: Custom arguments
54 |     pipeline_options: Beam PipelineOptions object
55 |
56 | Returns:
57 | Beam Pipeline object
58 | """
59 |
60 | p1 = beam.Pipeline(options=pipeline_options)
61 | _ = (
62 | p1
63 | | 'Create' >> beam.Create(example.generate_movie_examples())
64 | | 'Write'
65 | >> arrayrecordio.WriteToArrayRecord(
66 | args['output'],
67 | coder=coders.ToBytesCoder(),
68 | num_shards=num_shards,
69 | file_name_suffix='.arrayrecord',
70 | )
71 | )
72 | return p1
73 |
74 |
75 | def convert_tf_to_arrayrecord_disk(
76 | num_shards=1, args=def_args, pipeline_options=def_pipeline_options
77 | ):
78 | """Convert TFRecords to ArrayRecords using sink/sharding functionality.
79 |
80 | THIS ONLY WORKS FOR DISK ARRAYRECORD WRITES
81 |
82 | Args:
83 | num_shards: Number of files
84 | args: Custom arguments
85 |     pipeline_options: Beam PipelineOptions object
86 |
87 | Returns:
88 | Beam Pipeline object
89 | """
90 |
91 | p1 = beam.Pipeline(options=pipeline_options)
92 | _ = (
93 | p1
94 | | 'Read TFRecord' >> beam.io.ReadFromTFRecord(args['input'])
95 | | 'Write ArrayRecord'
96 | >> arrayrecordio.WriteToArrayRecord(
97 | args['output'],
98 | coder=coders.ToBytesCoder(),
99 | num_shards=num_shards,
100 | file_name_suffix='.arrayrecord',
101 | )
102 | )
103 | return p1
104 |
105 |
106 | def convert_tf_to_arrayrecord_disk_match_shards(
107 | args=def_args, pipeline_options=def_pipeline_options
108 | ):
109 | """Convert TFRecords to matching number of ArrayRecords.
110 |
111 | THIS ONLY WORKS FOR DISK ARRAYRECORD WRITES
112 |
113 | Args:
114 | args: Custom arguments
115 |     pipeline_options: Beam PipelineOptions object
116 |
117 | Returns:
118 | Beam Pipeline object
119 | """
120 |
121 | p1 = beam.Pipeline(options=pipeline_options)
122 | initial = (
123 | p1
124 | | 'Start' >> beam.Create([args['input']])
125 | | 'Read' >> beam.io.ReadAllFromTFRecord(with_filename=True)
126 | )
127 |
128 | file_count = (
129 | initial
130 | | 'Group' >> beam.GroupByKey()
131 | | 'Count Shards' >> beam.combiners.Count.Globally()
132 | )
133 |
134 | _ = (
135 | initial
136 | | 'Drop Filename' >> beam.Map(lambda x: x[1])
137 | | 'Write ArrayRecord'
138 | >> arrayrecordio.WriteToArrayRecord(
139 | args['output'],
140 | coder=coders.ToBytesCoder(),
141 | num_shards=beam.pvalue.AsSingleton(file_count),
142 | file_name_suffix='.arrayrecord',
143 | )
144 | )
145 | return p1
146 |
147 |
148 | def convert_tf_to_arrayrecord_gcs(
149 | overwrite_extension=False,
150 | file_path_suffix='.arrayrecord',
151 | args=def_args,
152 | pipeline_options=def_pipeline_options):
153 | """Convert TFRecords to ArrayRecords in GCS 1:1.
154 |
155 | Args:
156 | overwrite_extension: Boolean making DoFn attempt to overwrite extension
157 | file_path_suffix: Intended suffix for overwrite or append
158 | args: Custom arguments
159 |     pipeline_options: Beam PipelineOptions object
160 |
161 | Returns:
162 | Beam Pipeline object
163 | """
164 |
165 | p1 = beam.Pipeline(options=pipeline_options)
166 | _ = (
167 | p1
168 | | 'Start' >> beam.Create([args['input']])
169 | | 'Read' >> beam.io.ReadAllFromTFRecord(with_filename=True)
170 | | 'Group' >> beam.GroupByKey()
171 | | 'Write to ArrayRecord in GCS'
172 | >> beam.ParDo(
173 | dofns.ConvertToArrayRecordGCS(),
174 | args['output'],
175 | file_path_suffix=file_path_suffix,
176 | overwrite_extension=overwrite_extension,
177 | )
178 | )
179 | return p1
180 |
--------------------------------------------------------------------------------
/beam/testdata.py:
--------------------------------------------------------------------------------
1 | """Simple test data wrapper.
2 |
3 | Separated to keep TensorFlow and Beam dependencies away from test data
4 | """
5 |
6 | # Hardcoded multirecord dataset in dict format for testing and demo.
7 | data = [
8 | {
9 | 'Age': 29,
10 | 'Movie': ['The Shawshank Redemption', 'Fight Club'],
11 | 'Movie Ratings': [9.0, 9.7],
12 | 'Suggestion': 'Inception',
13 | 'Suggestion Purchased': 1.0,
14 | 'Purchase Price': 9.99
15 | },
16 | {
17 | 'Age': 39,
18 | 'Movie': ['The Prestige', 'The Big Lebowski', 'The Fall'],
19 | 'Movie Ratings': [9.5, 8.5, 8.5],
20 | 'Suggestion': 'Interstellar',
21 | 'Suggestion Purchased': 1.0,
22 | 'Purchase Price': 14.99
23 | },
24 | {
25 | 'Age': 19,
26 | 'Movie': ['Barbie', 'The Batman', 'Boss Baby', 'Oppenheimer'],
27 | 'Movie Ratings': [9.6, 8.2, 10.0, 4.2],
28 | 'Suggestion': 'Secret Life of Pets',
29 | 'Suggestion Purchased': 0.0,
30 | 'Purchase Price': 25.99
31 | },
32 | {
33 | 'Age': 35,
34 | 'Movie': ['The Mothman Prophecies', 'Sinister'],
35 | 'Movie Ratings': [8.3, 9.0],
36 | 'Suggestion': 'Hereditary',
37 | 'Suggestion Purchased': 1.0,
38 | 'Purchase Price': 12.99
39 | }
40 | ]
41 |
--------------------------------------------------------------------------------
/cpp/BUILD:
--------------------------------------------------------------------------------
1 | # ArrayRecord is a new file format for IO intensive applications.
2 | # It supports efficient random access and various compression algorithms.
3 |
4 | load("@rules_proto//proto:defs.bzl", "proto_library")
5 |
6 | package(default_visibility = ["//visibility:public"])
7 |
8 | licenses(["notice"])
9 |
10 | proto_library(
11 | name = "layout_proto",
12 | srcs = ["layout.proto"],
13 | )
14 |
15 | cc_proto_library(
16 | name = "layout_cc_proto",
17 | deps = [":layout_proto"],
18 | )
19 |
20 | cc_library(
21 | name = "common",
22 | hdrs = ["common.h"],
23 | deps = [
24 | "@abseil-cpp//absl/base:core_headers",
25 | "@abseil-cpp//absl/status",
26 | "@abseil-cpp//absl/strings:str_format",
27 | ],
28 | )
29 |
30 | cc_library(
31 | name = "sequenced_chunk_writer",
32 | srcs = ["sequenced_chunk_writer.cc"],
33 | hdrs = ["sequenced_chunk_writer.h"],
34 | deps = [
35 | ":common",
36 | "@abseil-cpp//absl/base:core_headers",
37 | "@abseil-cpp//absl/status",
38 | "@abseil-cpp//absl/status:statusor",
39 | "@abseil-cpp//absl/strings:str_format",
40 | "@abseil-cpp//absl/synchronization",
41 | "@riegeli//riegeli/base:initializer",
42 | "@riegeli//riegeli/base:object",
43 | "@riegeli//riegeli/base:status",
44 | "@riegeli//riegeli/base:types",
45 | "@riegeli//riegeli/bytes:writer",
46 | "@riegeli//riegeli/chunk_encoding:chunk",
47 | "@riegeli//riegeli/chunk_encoding:constants",
48 | "@riegeli//riegeli/records:chunk_writer",
49 | ],
50 | )
51 |
52 | cc_library(
53 | name = "thread_pool",
54 | srcs = ["thread_pool.cc"],
55 | hdrs = ["thread_pool.h"],
56 | deps = [
57 | "@abseil-cpp//absl/flags:flag",
58 | "@eigen//:eigen",
59 | ],
60 | )
61 |
62 | cc_library(
63 | name = "parallel_for",
64 | hdrs = ["parallel_for.h"],
65 | deps = [
66 | ":common",
67 | ":thread_pool",
68 | "@abseil-cpp//absl/base:core_headers",
69 | "@abseil-cpp//absl/status",
70 | "@abseil-cpp//absl/synchronization",
71 | ],
72 | )
73 |
74 | cc_library(
75 | name = "tri_state_ptr",
76 | hdrs = ["tri_state_ptr.h"],
77 | deps = [
78 | ":common",
79 | "@abseil-cpp//absl/base:core_headers",
80 | "@abseil-cpp//absl/synchronization",
81 | ],
82 | )
83 |
84 | cc_library(
85 | name = "test_utils",
86 | testonly = True,
87 | srcs = ["test_utils.cc"],
88 | hdrs = ["test_utils.h"],
89 | deps = [":common"],
90 | )
91 |
92 | cc_test(
93 | name = "test_utils_test",
94 | srcs = ["test_utils_test.cc"],
95 | deps = [
96 | ":common",
97 | ":test_utils",
98 | "@abseil-cpp//absl/strings",
99 | "@googletest//:gtest_main",
100 | ],
101 | )
102 |
103 | cc_library(
104 | name = "array_record_writer",
105 | srcs = ["array_record_writer.cc"],
106 | hdrs = ["array_record_writer.h"],
107 | deps = [
108 | ":common",
109 | ":layout_cc_proto",
110 | ":sequenced_chunk_writer",
111 | ":thread_pool",
112 | ":tri_state_ptr",
113 | "@abseil-cpp//absl/base:core_headers",
114 | "@abseil-cpp//absl/log:check",
115 | "@abseil-cpp//absl/status",
116 | "@abseil-cpp//absl/status:statusor",
117 | "@abseil-cpp//absl/strings",
118 | "@abseil-cpp//absl/strings:cord",
119 | "@abseil-cpp//absl/synchronization",
120 | "@abseil-cpp//absl/types:span",
121 | "@protobuf//:protobuf_lite",
122 | "@riegeli//riegeli/base:initializer",
123 | "@riegeli//riegeli/base:object",
124 | "@riegeli//riegeli/base:options_parser",
125 | "@riegeli//riegeli/base:status",
126 | "@riegeli//riegeli/bytes:chain_writer",
127 | "@riegeli//riegeli/bytes:writer",
128 | "@riegeli//riegeli/chunk_encoding:chunk",
129 | "@riegeli//riegeli/chunk_encoding:chunk_encoder",
130 | "@riegeli//riegeli/chunk_encoding:compressor_options",
131 | "@riegeli//riegeli/chunk_encoding:constants",
132 | "@riegeli//riegeli/chunk_encoding:deferred_encoder",
133 | "@riegeli//riegeli/chunk_encoding:simple_encoder",
134 | "@riegeli//riegeli/chunk_encoding:transpose_encoder",
135 | "@riegeli//riegeli/records:records_metadata_cc_proto",
136 | ],
137 | )
138 |
139 | cc_library(
140 | name = "masked_reader",
141 | srcs = ["masked_reader.cc"],
142 | hdrs = ["masked_reader.h"],
143 | deps = [
144 | ":common",
145 | "@abseil-cpp//absl/memory",
146 | "@abseil-cpp//absl/status",
147 | "@abseil-cpp//absl/time",
148 | "@abseil-cpp//absl/types:optional",
149 | "@riegeli//riegeli/base:object",
150 | "@riegeli//riegeli/base:status",
151 | "@riegeli//riegeli/base:types",
152 | "@riegeli//riegeli/bytes:reader",
153 | ],
154 | )
155 |
156 | cc_library(
157 | name = "array_record_reader",
158 | srcs = ["array_record_reader.cc"],
159 | hdrs = ["array_record_reader.h"],
160 | deps = [
161 | ":common",
162 | ":layout_cc_proto",
163 | ":masked_reader",
164 | ":parallel_for",
165 | ":thread_pool",
166 | ":tri_state_ptr",
167 | "@abseil-cpp//absl/base:core_headers",
168 | "@abseil-cpp//absl/functional:any_invocable",
169 | "@abseil-cpp//absl/functional:function_ref",
170 | "@abseil-cpp//absl/status",
171 | "@abseil-cpp//absl/status:statusor",
172 | "@abseil-cpp//absl/strings",
173 | "@abseil-cpp//absl/strings:str_format",
174 | "@abseil-cpp//absl/types:span",
175 | "@protobuf//:protobuf_lite",
176 | "@riegeli//riegeli/base:initializer",
177 | "@riegeli//riegeli/base:object",
178 | "@riegeli//riegeli/base:options_parser",
179 | "@riegeli//riegeli/base:status",
180 | "@riegeli//riegeli/bytes:reader",
181 | "@riegeli//riegeli/chunk_encoding:chunk",
182 | "@riegeli//riegeli/chunk_encoding:chunk_decoder",
183 | "@riegeli//riegeli/records:chunk_reader",
184 | ],
185 | )
186 |
187 | cc_test(
188 | name = "sequenced_chunk_writer_test",
189 | srcs = ["sequenced_chunk_writer_test.cc"],
190 | deps = [
191 | ":common",
192 | ":sequenced_chunk_writer",
193 | ":thread_pool",
194 | "@abseil-cpp//absl/status",
195 | "@abseil-cpp//absl/status:statusor",
196 | "@abseil-cpp//absl/strings:cord",
197 | "@abseil-cpp//absl/strings:string_view",
198 | "@abseil-cpp//absl/types:span",
199 | "@googletest//:gtest_main",
200 | "@riegeli//riegeli/base:initializer",
201 | "@riegeli//riegeli/base:shared_ptr",
202 | "@riegeli//riegeli/bytes:chain_writer",
203 | "@riegeli//riegeli/bytes:cord_writer",
204 | "@riegeli//riegeli/bytes:string_reader",
205 | "@riegeli//riegeli/bytes:string_writer",
206 | "@riegeli//riegeli/chunk_encoding:chunk",
207 | "@riegeli//riegeli/chunk_encoding:compressor_options",
208 | "@riegeli//riegeli/chunk_encoding:constants",
209 | "@riegeli//riegeli/chunk_encoding:simple_encoder",
210 | "@riegeli//riegeli/records:record_reader",
211 | ],
212 | )
213 |
214 | cc_test(
215 | name = "tri_state_ptr_test",
216 | srcs = ["tri_state_ptr_test.cc"],
217 | deps = [
218 | ":common",
219 | ":thread_pool",
220 | ":tri_state_ptr",
221 | "@abseil-cpp//absl/synchronization",
222 | "@googletest//:gtest_main",
223 | "@riegeli//riegeli/base:initializer",
224 | ],
225 | )
226 |
227 | cc_test(
228 | name = "array_record_writer_test",
229 | srcs = ["array_record_writer_test.cc"],
230 | shard_count = 4,
231 | tags = ["notsan"],
232 | deps = [
233 | ":array_record_writer",
234 | ":common",
235 | ":layout_cc_proto",
236 | ":test_utils",
237 | ":thread_pool",
238 | "@abseil-cpp//absl/strings",
239 | "@abseil-cpp//absl/strings:cord",
240 | "@abseil-cpp//absl/strings:cord_test_helpers",
241 | "@googletest//:gtest_main",
242 | "@riegeli//riegeli/base:initializer",
243 | "@riegeli//riegeli/bytes:string_reader",
244 | "@riegeli//riegeli/bytes:string_writer",
245 | "@riegeli//riegeli/chunk_encoding:constants",
246 | "@riegeli//riegeli/records:record_reader",
247 | "@riegeli//riegeli/records:records_metadata_cc_proto",
248 | ],
249 | )
250 |
251 | cc_test(
252 | name = "masked_reader_test",
253 | srcs = ["masked_reader_test.cc"],
254 | deps = [
255 | ":masked_reader",
256 | "@googletest//:gtest_main",
257 | "@riegeli//riegeli/bytes:string_reader",
258 | ],
259 | )
260 |
261 | cc_test(
262 | name = "parallel_for_test",
263 | size = "small",
264 | srcs = ["parallel_for_test.cc"],
265 | deps = [
266 | ":common",
267 | ":parallel_for",
268 | ":thread_pool",
269 | "@abseil-cpp//absl/functional:function_ref",
270 | "@abseil-cpp//absl/status",
271 | "@googletest//:gtest_main",
272 | ],
273 | )
274 |
275 | cc_test(
276 | name = "array_record_reader_test",
277 | srcs = ["array_record_reader_test.cc"],
278 | shard_count = 4,
279 | deps = [
280 | ":array_record_reader",
281 | ":array_record_writer",
282 | ":common",
283 | ":layout_cc_proto",
284 | ":test_utils",
285 | ":thread_pool",
286 | "@abseil-cpp//absl/functional:function_ref",
287 | "@abseil-cpp//absl/status",
288 | "@abseil-cpp//absl/strings",
289 | "@googletest//:gtest_main",
290 | "@riegeli//riegeli/base:initializer",
291 | "@riegeli//riegeli/bytes:string_reader",
292 | "@riegeli//riegeli/bytes:string_writer",
293 | ],
294 | )
295 |
--------------------------------------------------------------------------------
/cpp/array_record_reader_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/array_record_reader.h"
17 |
18 | #include <algorithm>
19 | #include <cstdint>
20 | #include <memory>
21 | #include <optional>
22 | #include <random>
23 | #include <string>
24 | #include <tuple>
25 | #include <utility>
26 | #include <vector>
27 |
28 | #include "gtest/gtest.h"
29 | #include "absl/functional/function_ref.h"
30 | #include "absl/status/status.h"
31 | #include "absl/strings/string_view.h"
32 | #include "cpp/array_record_writer.h"
33 | #include "cpp/common.h"
34 | #include "cpp/layout.pb.h"
35 | #include "cpp/test_utils.h"
36 | #include "cpp/thread_pool.h"
37 | #include "riegeli/base/maker.h"
38 | #include "riegeli/bytes/string_reader.h"
39 | #include "riegeli/bytes/string_writer.h"
40 |
41 | constexpr uint32_t kDatasetSize = 10050;
42 |
43 | namespace array_record {
44 | namespace {
45 |
46 | enum class CompressionType { kUncompressed, kBrotli, kZstd, kSnappy };
47 |
48 | // Tuple params
49 | // CompressionType
50 | // transpose
51 | // use_thread_pool
52 | // optimize_for_random_access
53 | class ArrayRecordReaderTest
54 | : public testing::TestWithParam<
55 |           std::tuple<CompressionType, bool, bool, bool>> {
56 | public:
57 | ARThreadPool* get_pool() { return ArrayRecordGlobalPool(); }
58 | ArrayRecordWriterBase::Options GetWriterOptions() {
59 | auto options = ArrayRecordWriterBase::Options();
60 | switch (std::get<0>(GetParam())) {
61 | case CompressionType::kUncompressed:
62 | options.set_uncompressed();
63 | break;
64 | case CompressionType::kBrotli:
65 | options.set_brotli();
66 | break;
67 | case CompressionType::kZstd:
68 | options.set_zstd();
69 | break;
70 | case CompressionType::kSnappy:
71 | options.set_snappy();
72 | break;
73 | }
74 | options.set_transpose(transpose());
75 | return options;
76 | }
77 |
78 | bool transpose() { return std::get<1>(GetParam()); }
79 | bool use_thread_pool() { return std::get<2>(GetParam()); }
80 | bool optimize_for_random_access() { return std::get<3>(GetParam()); }
81 | };
82 |
83 | TEST_P(ArrayRecordReaderTest, MoveTest) {
84 | std::string encoded;
85 | auto writer_options = GetWriterOptions().set_group_size(2);
86 | auto writer = ArrayRecordWriter(
87 |       riegeli::Maker<riegeli::StringWriter>(&encoded), writer_options, nullptr);
88 |
89 | // Empty string should not crash the writer or the reader.
90 |   std::vector<std::string> test_str{"aaa", "", "ccc", "dd", "e"};
91 | for (const auto& s : test_str) {
92 | EXPECT_TRUE(writer.WriteRecord(s));
93 | }
94 | ASSERT_TRUE(writer.Close());
95 |
96 | auto reader_opt = ArrayRecordReaderBase::Options();
97 | if (optimize_for_random_access()) {
98 | reader_opt.set_max_parallelism(0);
99 | reader_opt.set_readahead_buffer_size(0);
100 | }
101 |
102 | auto reader_before_move =
103 |       ArrayRecordReader(riegeli::Maker<riegeli::StringReader>(encoded),
104 | reader_opt, use_thread_pool() ? get_pool() : nullptr);
105 | ASSERT_TRUE(reader_before_move.status().ok());
106 |
107 | ASSERT_TRUE(
108 | reader_before_move
109 | .ParallelReadRecords([&](uint64_t record_index,
110 | absl::string_view record) -> absl::Status {
111 | EXPECT_EQ(record, test_str[record_index]);
112 | return absl::OkStatus();
113 | })
114 | .ok());
115 |
116 | EXPECT_EQ(reader_before_move.RecordGroupSize(), 2);
117 |
118 | ArrayRecordReader reader = std::move(reader_before_move);
119 | // Once a reader is moved, it is closed.
120 | ASSERT_FALSE(reader_before_move.is_open()); // NOLINT
121 |
122 | auto recorded_writer_options = ArrayRecordWriterBase::Options::FromString(
123 | reader.WriterOptionsString().value())
124 | .value();
125 | EXPECT_EQ(writer_options.compression_type(),
126 | recorded_writer_options.compression_type());
127 | EXPECT_EQ(writer_options.compression_level(),
128 | recorded_writer_options.compression_level());
129 | EXPECT_EQ(writer_options.transpose(), recorded_writer_options.transpose());
130 |
131 |   std::vector<uint64_t> indices = {1, 2, 4};
132 | ASSERT_TRUE(reader
133 | .ParallelReadRecordsWithIndices(
134 | indices,
135 | [&](uint64_t indices_idx,
136 | absl::string_view record) -> absl::Status {
137 | EXPECT_EQ(record, test_str[indices[indices_idx]]);
138 | return absl::OkStatus();
139 | })
140 | .ok());
141 |
142 | absl::string_view record_view;
143 | for (auto i : IndicesOf(test_str)) {
144 | EXPECT_TRUE(reader.ReadRecord(&record_view));
145 | EXPECT_EQ(record_view, test_str[i]);
146 | }
147 | // Cannot read once we are at the end of the file.
148 | EXPECT_FALSE(reader.ReadRecord(&record_view));
149 | // But the reader should still be healthy.
150 | EXPECT_TRUE(reader.ok());
151 |
152 | // Seek to a particular record works.
153 | EXPECT_TRUE(reader.SeekRecord(2));
154 | EXPECT_TRUE(reader.ReadRecord(&record_view));
155 | EXPECT_EQ(record_view, test_str[2]);
156 |
157 |   // Seeking out of bounds does not fail.
158 | EXPECT_TRUE(reader.SeekRecord(10));
159 | EXPECT_FALSE(reader.ReadRecord(&record_view));
160 | EXPECT_TRUE(reader.ok());
161 |
162 | EXPECT_EQ(reader.RecordGroupSize(), 2);
163 |
164 | ASSERT_TRUE(reader.Close());
165 | }
166 |
167 | TEST_P(ArrayRecordReaderTest, RandomDatasetTest) {
168 | std::mt19937 bitgen;
169 |   std::vector<std::string> records(kDatasetSize);
170 | std::uniform_int_distribution<> dist(0, 123);
171 | for (auto i : Seq(kDatasetSize)) {
172 | size_t len = dist(bitgen);
173 | records[i] = MTRandomBytes(bitgen, len);
174 | }
175 |
176 | std::string encoded;
177 | auto writer =
178 |       ArrayRecordWriter(riegeli::Maker<riegeli::StringWriter>(&encoded),
179 | GetWriterOptions(), get_pool());
180 | for (auto i : Seq(kDatasetSize)) {
181 | EXPECT_TRUE(writer.WriteRecord(records[i]));
182 | }
183 | ASSERT_TRUE(writer.Close());
184 |
185 | auto reader_opt = ArrayRecordReaderBase::Options();
186 | if (optimize_for_random_access()) {
187 | reader_opt.set_max_parallelism(0);
188 | reader_opt.set_readahead_buffer_size(0);
189 | }
190 |
191 | auto reader =
192 |       ArrayRecordReader(riegeli::Maker<riegeli::StringReader>(encoded),
193 | reader_opt, use_thread_pool() ? get_pool() : nullptr);
194 | ASSERT_TRUE(reader.status().ok());
195 | EXPECT_EQ(reader.NumRecords(), kDatasetSize);
196 | uint64_t group_size =
197 | std::min(ArrayRecordWriterBase::Options::kDefaultGroupSize, kDatasetSize);
198 | EXPECT_EQ(reader.RecordGroupSize(), group_size);
199 |
200 |   std::vector<bool> read_all_records(kDatasetSize, false);
201 | ASSERT_TRUE(reader
202 | .ParallelReadRecords(
203 | [&](uint64_t record_index,
204 | absl::string_view result_view) -> absl::Status {
205 | EXPECT_EQ(result_view, records[record_index]);
206 | EXPECT_FALSE(read_all_records[record_index]);
207 | read_all_records[record_index] = true;
208 | return absl::OkStatus();
209 | })
210 | .ok());
211 | for (bool record_was_read : read_all_records) {
212 | EXPECT_TRUE(record_was_read);
213 | }
214 |
215 |   std::vector<uint64_t> indices = {0, 3, 5, 7, 101, 2000};
216 |   std::vector<bool> read_indexed_records(indices.size(), false);
217 | ASSERT_TRUE(reader
218 | .ParallelReadRecordsWithIndices(
219 | indices,
220 | [&](uint64_t indices_idx,
221 | absl::string_view result_view) -> absl::Status {
222 | EXPECT_EQ(result_view, records[indices[indices_idx]]);
223 | EXPECT_FALSE(read_indexed_records[indices_idx]);
224 | read_indexed_records[indices_idx] = true;
225 | return absl::OkStatus();
226 | })
227 | .ok());
228 | for (bool record_was_read : read_indexed_records) {
229 | EXPECT_TRUE(record_was_read);
230 | }
231 |
232 | uint64_t begin = 10, end = 101;
233 |   std::vector<bool> read_range_records(end - begin, false);
234 | ASSERT_TRUE(reader
235 | .ParallelReadRecordsInRange(
236 | begin, end,
237 | [&](uint64_t record_index,
238 | absl::string_view result_view) -> absl::Status {
239 | EXPECT_EQ(result_view, records[record_index]);
240 | EXPECT_FALSE(read_range_records[record_index - begin]);
241 | read_range_records[record_index - begin] = true;
242 | return absl::OkStatus();
243 | })
244 | .ok());
245 | for (bool record_was_read : read_range_records) {
246 | EXPECT_TRUE(record_was_read);
247 | }
248 |
249 | // Test sequential read
250 | absl::string_view result_view;
251 | for (auto record_index : Seq(kDatasetSize)) {
252 | ASSERT_TRUE(reader.ReadRecord(&result_view));
253 | EXPECT_EQ(result_view, records[record_index]);
254 | }
255 |   // Reached the end.
256 | EXPECT_FALSE(reader.ReadRecord(&result_view));
257 | EXPECT_TRUE(reader.ok());
258 | // We can still seek back.
259 | EXPECT_TRUE(reader.SeekRecord(5));
260 | EXPECT_TRUE(reader.ReadRecord(&result_view));
261 | EXPECT_EQ(result_view, records[5]);
262 |
263 | ASSERT_TRUE(reader.Close());
264 | }
265 |
266 | INSTANTIATE_TEST_SUITE_P(
267 | ParamTest, ArrayRecordReaderTest,
268 | testing::Combine(testing::Values(CompressionType::kUncompressed,
269 | CompressionType::kBrotli,
270 | CompressionType::kZstd,
271 | CompressionType::kSnappy),
272 | testing::Bool(), testing::Bool(), testing::Bool()));
273 |
274 | TEST(ArrayRecordReaderOptionTest, ParserTest) {
275 | {
276 | auto option = ArrayRecordReaderBase::Options::FromString("").value();
277 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
278 | EXPECT_EQ(option.readahead_buffer_size(),
279 | ArrayRecordReaderBase::Options::kDefaultReadaheadBufferSize);
280 | }
281 | {
282 | auto option =
283 | ArrayRecordReaderBase::Options::FromString("max_parallelism:16")
284 | .value();
285 | EXPECT_EQ(option.max_parallelism(), 16);
286 | EXPECT_EQ(option.readahead_buffer_size(),
287 | ArrayRecordReaderBase::Options::kDefaultReadaheadBufferSize);
288 | }
289 | {
290 | auto option = ArrayRecordReaderBase::Options::FromString(
291 | "max_parallelism:16,readahead_buffer_size:16384")
292 | .value();
293 | EXPECT_EQ(option.max_parallelism(), 16);
294 | EXPECT_EQ(option.readahead_buffer_size(), 16384);
295 | }
296 | {
297 | auto option = ArrayRecordReaderBase::Options::FromString(
298 | "max_parallelism:0,readahead_buffer_size:0")
299 | .value();
300 | EXPECT_EQ(option.max_parallelism(), 0);
301 | EXPECT_EQ(option.readahead_buffer_size(), 0);
302 | }
303 | }
304 |
305 | } // namespace
306 | } // namespace array_record
307 |
--------------------------------------------------------------------------------
/cpp/array_record_writer_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/array_record_writer.h"
17 |
18 | #include <cstdint>
19 | #include <memory>
20 | #include <optional>
21 | #include <random>
22 | #include <string>
23 | #include <tuple>
24 | #include <utility>
25 | #include <vector>
26 |
27 | #include "gtest/gtest.h"
28 | #include "absl/strings/cord.h"
29 | #include "absl/strings/cord_test_helpers.h"
30 | #include "absl/strings/string_view.h"
31 | #include "cpp/common.h"
32 | #include "cpp/layout.pb.h"
33 | #include "cpp/test_utils.h"
34 | #include "cpp/thread_pool.h"
35 | #include "riegeli/base/maker.h"
36 | #include "riegeli/bytes/string_reader.h"
37 | #include "riegeli/bytes/string_writer.h"
38 | #include "riegeli/chunk_encoding/constants.h"
39 | #include "riegeli/records/record_reader.h"
40 | #include "riegeli/records/records_metadata.pb.h"
41 |
42 | namespace array_record {
43 |
44 | namespace {
45 |
46 | enum class CompressionType { kUncompressed, kBrotli, kZstd, kSnappy };
47 |
48 | // Tuple params
49 | // CompressionType
50 | // padding
51 | // transpose
52 | // use ThreadPool
53 | class ArrayRecordWriterTest
54 | : public testing::TestWithParam<
55 |           std::tuple<CompressionType, bool, bool, bool>> {
56 | public:
57 | ArrayRecordWriterBase::Options GetOptions() {
58 | auto options = ArrayRecordWriterBase::Options();
59 | switch (std::get<0>(GetParam())) {
60 | case CompressionType::kUncompressed:
61 | options.set_uncompressed();
62 | break;
63 | case CompressionType::kBrotli:
64 | options.set_brotli();
65 | break;
66 | case CompressionType::kZstd:
67 | options.set_zstd();
68 | break;
69 | case CompressionType::kSnappy:
70 | options.set_snappy();
71 | break;
72 | }
73 | options.set_pad_to_block_boundary(std::get<1>(GetParam()));
74 | options.set_transpose(std::get<2>(GetParam()));
75 | return options;
76 | }
77 | };
78 |
79 | template <typename T>
80 | void SilenceMoveAfterUseForTest(T&) {}
81 |
82 | TEST_P(ArrayRecordWriterTest, MoveTest) {
83 | std::string encoded;
84 | ARThreadPool* pool = nullptr;
85 | if (std::get<3>(GetParam())) {
86 | pool = ArrayRecordGlobalPool();
87 | }
88 | auto options = GetOptions();
89 | options.set_group_size(2);
90 | auto writer = ArrayRecordWriter(
91 |       riegeli::Maker<riegeli::StringWriter>(&encoded), options, pool);
92 |
93 | // Empty string should not crash the writer/reader.
94 |   std::vector<std::string> test_str{"aaa", "", "ccc", "dd", "e"};
95 | for (auto i : Seq(3)) {
96 | EXPECT_TRUE(writer.WriteRecord(test_str[i]));
97 | }
98 |
99 | auto moved_writer = std::move(writer);
100 | SilenceMoveAfterUseForTest(writer);
101 | // Once moved, writer is closed.
102 | ASSERT_FALSE(writer.is_open());
103 | ASSERT_TRUE(moved_writer.is_open());
104 | // Once moved we can no longer write records.
105 | EXPECT_FALSE(writer.WriteRecord(test_str[3]));
106 |
107 | ASSERT_TRUE(moved_writer.status().ok());
108 | EXPECT_TRUE(moved_writer.WriteRecord(test_str[3]));
109 | EXPECT_TRUE(moved_writer.WriteRecord(test_str[4]));
110 | ASSERT_TRUE(moved_writer.Close());
111 |
112 | auto reader =
113 |       riegeli::RecordReader(riegeli::Maker<riegeli::StringReader>(encoded));
114 | for (const auto& expected : test_str) {
115 | std::string result;
116 | reader.ReadRecord(result);
117 | EXPECT_EQ(result, expected);
118 | }
119 | }
120 |
121 | TEST_P(ArrayRecordWriterTest, CordTest) {
122 | std::string encoded;
123 | ARThreadPool* pool = nullptr;
124 | if (std::get<3>(GetParam())) {
125 | pool = ArrayRecordGlobalPool();
126 | }
127 | auto options = GetOptions();
128 | options.set_group_size(2);
129 | auto writer = ArrayRecordWriter(
130 |       riegeli::Maker<riegeli::StringWriter>(&encoded), options, pool);
131 |
132 | absl::Cord flat_cord("test");
133 | // Empty string should not crash the writer.
134 | absl::Cord empty_cord("");
135 | absl::Cord fragmented_cord = absl::MakeFragmentedCord({"aaa ", "", "c"});
136 |
137 | EXPECT_TRUE(writer.WriteRecord(flat_cord));
138 | EXPECT_TRUE(writer.WriteRecord(empty_cord));
139 | EXPECT_TRUE(writer.WriteRecord(fragmented_cord));
140 | ASSERT_TRUE(writer.Close());
141 |
142 | // Empty string should not crash the reader.
143 |   std::vector<std::string> expected_strings{"test", "", "aaa c"};
144 |
145 | auto reader =
146 |       riegeli::RecordReader(riegeli::Maker<riegeli::StringReader>(encoded));
147 | for (const auto& expected : expected_strings) {
148 | std::string result;
149 | reader.ReadRecord(result);
150 | EXPECT_EQ(result, expected);
151 | }
152 | }
153 |
154 | TEST_P(ArrayRecordWriterTest, RandomDatasetTest) {
155 | std::mt19937 bitgen;
156 | constexpr uint32_t kGroupSize = 100;
157 | constexpr uint32_t num_records = 1357;
158 |   std::vector<std::string> records(num_records);
159 | std::uniform_int_distribution<> dist(0, 123);
160 | for (auto i : Seq(num_records)) {
161 | size_t len = dist(bitgen);
162 | records[i] = MTRandomBytes(bitgen, len);
163 | }
164 | // results are stored in encoded
165 | std::string encoded;
166 |
167 | ARThreadPool* pool = nullptr;
168 | if (std::get<3>(GetParam())) {
169 | pool = ArrayRecordGlobalPool();
170 | }
171 | auto options = GetOptions();
172 | options.set_group_size(kGroupSize);
173 |
174 | auto writer = ArrayRecordWriter(
175 |       riegeli::Maker<riegeli::StringWriter>(&encoded), options, pool);
176 |
177 | for (auto i : Seq(num_records)) {
178 | EXPECT_TRUE(writer.WriteRecord(records[i]));
179 | }
180 | ASSERT_TRUE(writer.Close());
181 |
182 | auto reader =
183 |       riegeli::RecordReader(riegeli::Maker<riegeli::StringReader>(encoded));
184 |
185 | // Verify metadata
186 | ASSERT_TRUE(reader.CheckFileFormat());
187 |
188 | // Verify each record
189 | for (auto i : Seq(num_records)) {
190 | absl::string_view result_view;
191 | ASSERT_TRUE(reader.ReadRecord(result_view));
192 | EXPECT_EQ(result_view, records[i]);
193 | }
194 |
195 |   // Verify postscript
196 | ASSERT_TRUE(reader.Seek(reader.Size().value() - (1 << 16)));
197 | RiegeliPostscript postscript;
198 | ASSERT_TRUE(reader.ReadRecord(postscript));
199 | ASSERT_EQ(postscript.magic(), 0x71930e704fdae05eULL);
200 |
201 | // Verify Footer
202 | ASSERT_TRUE(reader.Seek(postscript.footer_offset()));
203 | RiegeliFooterMetadata footer_metadata;
204 | ASSERT_TRUE(reader.ReadRecord(footer_metadata));
205 | ASSERT_EQ(footer_metadata.array_record_metadata().version(), 1);
206 | auto num_chunks = footer_metadata.array_record_metadata().num_chunks();
207 |   std::vector<ArrayRecordFooter> footers(num_chunks);
208 | for (auto i : Seq(num_chunks)) {
209 | ASSERT_TRUE(reader.ReadRecord(footers[i]));
210 | }
211 |
212 | // Verify we can access the file randomly by chunk_offset recorded in the
213 | // footer
214 | for (auto i = 0UL; i < num_chunks; ++i) {
215 | ASSERT_TRUE(reader.Seek(footers[i].chunk_offset()));
216 | absl::string_view result_view;
217 | ASSERT_TRUE(reader.ReadRecord(result_view)) << reader.status();
218 | EXPECT_EQ(result_view, records[i * kGroupSize]);
219 | }
220 | ASSERT_TRUE(reader.Close());
221 | }
222 |
223 | INSTANTIATE_TEST_SUITE_P(
224 | ParamTest, ArrayRecordWriterTest,
225 | testing::Combine(testing::Values(CompressionType::kUncompressed,
226 | CompressionType::kBrotli,
227 | CompressionType::kZstd,
228 | CompressionType::kSnappy),
229 | testing::Bool(), testing::Bool(), testing::Bool()));
230 |
231 | TEST(ArrayRecordWriterOptionsTest, ParsingTest) {
232 | {
233 | auto option = ArrayRecordWriterBase::Options();
234 | EXPECT_EQ(option.group_size(),
235 | ArrayRecordWriterBase::Options::kDefaultGroupSize);
236 | EXPECT_FALSE(option.transpose());
237 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
238 | EXPECT_EQ(option.compressor_options().compression_type(),
239 | riegeli::CompressionType::kZstd);
240 | EXPECT_EQ(option.compressor_options().compression_level(), 3);
241 | EXPECT_FALSE(option.pad_to_block_boundary());
242 | }
243 | {
244 | auto option = ArrayRecordWriterBase::Options::FromString("").value();
245 | EXPECT_EQ(option.group_size(),
246 | ArrayRecordWriterBase::Options::kDefaultGroupSize);
247 | EXPECT_FALSE(option.transpose());
248 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
249 | EXPECT_EQ(option.compressor_options().compression_type(),
250 | riegeli::CompressionType::kZstd);
251 | EXPECT_EQ(option.compressor_options().compression_level(), 3);
252 | EXPECT_EQ(option.compressor_options().window_log().value(), 20);
253 | EXPECT_FALSE(option.pad_to_block_boundary());
254 |
255 | EXPECT_EQ(option.ToString(),
256 | "group_size:65536,"
257 | "transpose:false,"
258 | "pad_to_block_boundary:false,"
259 | "zstd:3,"
260 | "window_log:20");
261 | EXPECT_TRUE(
262 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
263 | }
264 | {
265 | auto option = ArrayRecordWriterBase::Options::FromString("default").value();
266 | EXPECT_EQ(option.group_size(),
267 | ArrayRecordWriterBase::Options::kDefaultGroupSize);
268 | EXPECT_FALSE(option.transpose());
269 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
270 | EXPECT_EQ(option.compressor_options().compression_type(),
271 | riegeli::CompressionType::kZstd);
272 | EXPECT_EQ(option.compressor_options().compression_level(), 3);
273 | EXPECT_EQ(option.compressor_options().window_log().value(), 20);
274 | EXPECT_FALSE(option.pad_to_block_boundary());
275 |
276 | EXPECT_EQ(option.ToString(),
277 | "group_size:65536,"
278 | "transpose:false,"
279 | "pad_to_block_boundary:false,"
280 | "zstd:3,"
281 | "window_log:20");
282 | EXPECT_TRUE(
283 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
284 | }
285 | {
286 | auto option = ArrayRecordWriterBase::Options::FromString(
287 | "group_size:32,transpose,window_log:20")
288 | .value();
289 | EXPECT_EQ(option.group_size(), 32);
290 | EXPECT_TRUE(option.transpose());
291 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
292 | EXPECT_EQ(option.compressor_options().compression_type(),
293 | riegeli::CompressionType::kZstd);
294 | EXPECT_EQ(option.compressor_options().window_log(), 20);
295 | EXPECT_FALSE(option.pad_to_block_boundary());
296 |
297 | EXPECT_EQ(option.ToString(),
298 | "group_size:32,"
299 | "transpose:true,"
300 | "pad_to_block_boundary:false,"
301 | "transpose_bucket_size:256,"
302 | "zstd:3,"
303 | "window_log:20");
304 | EXPECT_TRUE(
305 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
306 | }
307 | {
308 | auto option = ArrayRecordWriterBase::Options::FromString(
309 | "brotli:6,group_size:32,transpose,window_log:25")
310 | .value();
311 | EXPECT_EQ(option.group_size(), 32);
312 | EXPECT_TRUE(option.transpose());
313 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
314 | EXPECT_EQ(option.compressor_options().compression_type(),
315 | riegeli::CompressionType::kBrotli);
316 | EXPECT_EQ(option.compressor_options().window_log(), 25);
317 | EXPECT_FALSE(option.pad_to_block_boundary());
318 |
319 | EXPECT_EQ(option.ToString(),
320 | "group_size:32,"
321 | "transpose:true,"
322 | "pad_to_block_boundary:false,"
323 | "transpose_bucket_size:256,"
324 | "brotli:6,"
325 | "window_log:25");
326 | EXPECT_TRUE(
327 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
328 | }
329 | {
330 | auto option = ArrayRecordWriterBase::Options::FromString(
331 | "group_size:32,transpose,zstd:5")
332 | .value();
333 | EXPECT_EQ(option.group_size(), 32);
334 | EXPECT_TRUE(option.transpose());
335 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
336 | EXPECT_EQ(option.compressor_options().compression_type(),
337 | riegeli::CompressionType::kZstd);
338 | EXPECT_EQ(option.compressor_options().window_log(), 20);
339 | EXPECT_EQ(option.compressor_options().compression_level(), 5);
340 | EXPECT_FALSE(option.pad_to_block_boundary());
341 |
342 | EXPECT_EQ(option.ToString(),
343 | "group_size:32,"
344 | "transpose:true,"
345 | "pad_to_block_boundary:false,"
346 | "transpose_bucket_size:256,"
347 | "zstd:5,"
348 | "window_log:20");
349 | EXPECT_TRUE(
350 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
351 | }
352 | {
353 | auto option = ArrayRecordWriterBase::Options::FromString(
354 | "uncompressed,pad_to_block_boundary:true")
355 | .value();
356 | EXPECT_EQ(option.group_size(),
357 | ArrayRecordWriterBase::Options::kDefaultGroupSize);
358 | EXPECT_FALSE(option.transpose());
359 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
360 | EXPECT_EQ(option.compressor_options().compression_type(),
361 | riegeli::CompressionType::kNone);
362 | EXPECT_TRUE(option.pad_to_block_boundary());
363 |
364 | EXPECT_EQ(option.ToString(),
365 | "group_size:65536,"
366 | "transpose:false,"
367 | "pad_to_block_boundary:true,"
368 | "uncompressed");
369 | EXPECT_TRUE(
370 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
371 | }
372 | {
373 | auto option = ArrayRecordWriterBase::Options::FromString(
374 | "snappy,pad_to_block_boundary:true")
375 | .value();
376 | EXPECT_EQ(option.group_size(),
377 | ArrayRecordWriterBase::Options::kDefaultGroupSize);
378 | EXPECT_FALSE(option.transpose());
379 | EXPECT_EQ(option.max_parallelism(), std::nullopt);
380 | EXPECT_EQ(option.compressor_options().compression_type(),
381 | riegeli::CompressionType::kSnappy);
382 | EXPECT_TRUE(option.pad_to_block_boundary());
383 |
384 | EXPECT_EQ(option.ToString(),
385 | "group_size:65536,"
386 | "transpose:false,"
387 | "pad_to_block_boundary:true,"
388 | "snappy");
389 | EXPECT_TRUE(
390 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok());
391 | }
392 | }
393 |
394 | } // namespace
395 | } // namespace array_record
396 |
--------------------------------------------------------------------------------
/cpp/common.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_COMMON_H_
17 | #define ARRAY_RECORD_CPP_COMMON_H_
18 |
19 | #include "absl/base/attributes.h"
20 | #include "absl/status/status.h"
21 | #include "absl/strings/str_format.h"
22 |
23 | namespace array_record {
24 |
25 | ////////////////////////////////////////////////////////////////////////////////
26 | // Canonical Errors (with formatting!)
27 | ////////////////////////////////////////////////////////////////////////////////
28 |
29 | template <typename... Args>
30 | ABSL_MUST_USE_RESULT absl::Status FailedPreconditionError(
31 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
32 | return absl::FailedPreconditionError(absl::StrFormat(fmt, args...));
33 | }
34 |
35 | template <typename... Args>
36 | ABSL_MUST_USE_RESULT absl::Status InternalError(
37 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
38 | return absl::InternalError(absl::StrFormat(fmt, args...));
39 | }
40 |
41 | template <typename... Args>
42 | ABSL_MUST_USE_RESULT absl::Status InvalidArgumentError(
43 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
44 | return absl::InvalidArgumentError(absl::StrFormat(fmt, args...));
45 | }
46 |
47 | template <typename... Args>
48 | ABSL_MUST_USE_RESULT absl::Status NotFoundError(
49 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
50 | return absl::NotFoundError(absl::StrFormat(fmt, args...));
51 | }
52 |
53 | template <typename... Args>
54 | ABSL_MUST_USE_RESULT absl::Status OutOfRangeError(
55 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
56 | return absl::OutOfRangeError(absl::StrFormat(fmt, args...));
57 | }
58 |
59 | template <typename... Args>
60 | ABSL_MUST_USE_RESULT absl::Status UnavailableError(
61 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
62 | return absl::UnavailableError(absl::StrFormat(fmt, args...));
63 | }
64 |
65 | template <typename... Args>
66 | ABSL_MUST_USE_RESULT absl::Status UnimplementedError(
67 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
68 | return absl::UnimplementedError(absl::StrFormat(fmt, args...));
69 | }
70 |
71 | template <typename... Args>
72 | ABSL_MUST_USE_RESULT absl::Status UnknownError(
73 |     const absl::FormatSpec<Args...>& fmt, const Args&... args) {
74 | return absl::UnknownError(absl::StrFormat(fmt, args...));
75 | }
76 |
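// For example, a hypothetical call site (illustrative variable names) that
// builds a formatted status:
//
//   absl::Status s =
//       OutOfRangeError("Record index %d is out of range [0, %d)", idx, n);
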
77 | // TODO(fchern): Align with what XLA do.
78 | template <typename Int, typename DenomInt>
79 | constexpr Int DivRoundUp(Int num, DenomInt denom) {
80 | // Note: we want DivRoundUp(my_uint64, 17) to just work, so we cast the denom
81 | // to the numerator's type. The result of division always fits in the
82 | // numerator's type, so this is very safe.
83 |   return (num + static_cast<Int>(denom) - static_cast<Int>(1)) /
84 |          static_cast<Int>(denom);
85 | }
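
// For instance, DivRoundUp(10, 4) == 3 and DivRoundUp(8, 4) == 2; per the
// comment above, DivRoundUp(uint64_t{1} << 32, 17) stays within uint64_t.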
86 |
87 | ////////////////////////////////////////////////////////////////////////////////
88 | // Class Decorators
89 | ////////////////////////////////////////////////////////////////////////////////
90 |
91 | #define DECLARE_COPYABLE_CLASS(ClassName) \
92 | ClassName(ClassName&&) = default; \
93 | ClassName& operator=(ClassName&&) = default; \
94 | ClassName(const ClassName&) = default; \
95 | ClassName& operator=(const ClassName&) = default
96 |
97 | #define DECLARE_MOVE_ONLY_CLASS(ClassName) \
98 | ClassName(ClassName&&) = default; \
99 | ClassName& operator=(ClassName&&) = default; \
100 | ClassName(const ClassName&) = delete; \
101 | ClassName& operator=(const ClassName&) = delete
102 |
103 | #define DECLARE_IMMOBILE_CLASS(ClassName) \
104 | ClassName(ClassName&&) = delete; \
105 | ClassName& operator=(ClassName&&) = delete; \
106 | ClassName(const ClassName&) = delete; \
107 | ClassName& operator=(const ClassName&) = delete
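
// Example: a move-only wrapper (hypothetical class) would be declared as
//
//   class FileHandle {
//    public:
//     DECLARE_MOVE_ONLY_CLASS(FileHandle);
//     ...
//   };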
108 |
109 | ////////////////////////////////////////////////////////////////////////////////
110 | // Seq / SeqWithStride / IndicesOf
111 | ////////////////////////////////////////////////////////////////////////////////
112 | //
113 | // Seq facilitates iterating over [begin, end) index ranges.
114 | //
115 | // * Avoids 3X stutter of the 'idx' variable, facilitating use of more
116 | // descriptive variable names like 'datapoint_idx', 'centroid_idx', etc.
117 | //
118 | // * Unifies the syntax between ParallelFor and vanilla for-loops.
119 | //
120 | // * Reverse iteration is much easier to read and less error prone.
121 | //
122 | // * Strided iteration becomes harder to miss when skimming code.
123 | //
124 | // * Reduction in boilerplate '=', '<', '+=' symbols makes it easier to
125 | //    skim-read code with lots of small for-loops interleaved with
126 | //    operator-heavy logic (i.e., most of ScaM).
127 | //
128 | // * Zero runtime overhead.
129 | //
130 | //
131 | // Equivalent for-loops (basic iteration):
132 | //
133 | //   for (size_t idx : Seq(collection.size())) { ... }
134 | //   for (size_t idx : Seq(0, collection.size())) { ... }
135 | // for (size_t idx = 0; idx < collection.size(); idx++) { ... }
136 | //
137 | //
138 | // In particular, reverse iteration becomes much simpler and more readable:
139 | //
140 | // for (size_t idx : ReverseSeq(collection.size())) { ... }
141 | // for (ssize_t idx = collection.size() - 1; idx >= 0; idx--) { ... }
142 | //
143 | //
144 | // Strided iteration works too:
145 | //
146 | // for (size_t idx : SeqWithStride<8>(filenames.size())) { ... }
147 | // for (size_t idx = 0; idx < filenames.size(); idx += 8) { ... }
148 | //
149 | //
150 | // Iteration count without using a variable:
151 | //
152 | // for (auto _ : Seq(16)) { ... }
153 | //
154 | //
155 | // Clarifies the ParallelFor syntax:
156 | //
157 | // ParallelFor<1>(Seq(dataset.size()), &pool, [&](size_t datapoint_idx) {
158 | // ...
159 | // });
160 | //
161 | template <ssize_t kStride>
162 | class SeqWithStride {
163 | public:
164 | static constexpr size_t Stride() { return kStride; }
165 |
166 | // Constructor for iterating [0, end).
167 | inline explicit SeqWithStride(size_t end) : begin_(0), end_(end) {}
168 |
169 | // Constructor for iterating [begin, end).
170 | inline SeqWithStride(size_t begin, size_t end) : begin_(begin), end_(end) {
171 | static_assert(kStride != 0);
172 | }
173 |
174 | // SizeT is an internal detail that helps suppress 'unused variable' compiler
175 | // errors. It's implicitly convertible to size_t, but by virtue of having a
176 | // destructor, the compiler doesn't complain about unused SizeT variables.
177 | //
178 | // These are equivalent:
179 | //
180 | // for (auto _ : Seq(10)) // Suppresses 'unused variable' error.
181 | // for (SizeT _ : Seq(10)) // Suppresses 'unused variable' error.
182 | //
183 | // Prefer the 'auto' variant. Don't use SizeT directly.
184 | //
185 | class SizeT {
186 | public:
187 | // Implicit SizeT <=> SizeT conversions.
188 | inline SizeT(size_t val) : val_(val) {} // NOLINT
189 | inline operator size_t() const { return val_; } // NOLINT
190 |
191 | // Defining a destructor suppresses 'unused variable' errors for the
192 | // following pattern: for (auto _ : Seq(kNumIters)) { ... }
193 | inline ~SizeT() {}
194 |
195 | private:
196 | size_t val_;
197 | };
198 |
199 | // Iterator implements the "!=", "++" and "*" operators required to support
200 | // the C++ for-each syntax. Not intended for direct use.
201 | class Iterator {
202 | public:
203 | // Constructor.
204 | inline explicit Iterator(size_t idx) : idx_(idx) {}
205 | // The '*' operator.
206 | inline SizeT operator*() const { return idx_; }
207 | // The '++' operator.
208 | inline Iterator& operator++() {
209 | idx_ += kStride;
210 | return *this;
211 | }
212 | // The '!=' operator.
213 | inline bool operator!=(Iterator end) const {
214 | // Note: The comparison below is "<", not "!=", in order to generate the
215 | // correct behavior when (end - begin) is not a multiple of kStride; note
216 | // that the Iterator class only exists to support the C++ for-each syntax,
217 | // and is *not* intended for direct use.
218 | //
219 |     // Consider the case where (end - begin) is not a multiple of kStride:
220 | //
221 | // for (size_t j : SeqWithStride<5>(9)) { ... }
222 | // for (size_t j = 0; j < 9; j += 5) { ... } // '==' wouldn't work.
223 | //
224 | if constexpr (kStride > 0) {
225 | return idx_ < end.idx_;
226 | }
227 |
228 | // The reverse-iteration case:
229 | //
230 | // for (size_t j : ReverseSeq(sz)) { ... }
231 | // for (ssize_t j = sz-1; j >= 0; j -= 5) { ... }
232 | //
233 |     return static_cast<ssize_t>(idx_) >= static_cast<ssize_t>(end.idx_);
234 | }
235 |
236 | private:
237 | size_t idx_;
238 | };
239 | using iterator = Iterator;
240 |
241 | inline Iterator begin() const { return Iterator(begin_); }
242 | inline Iterator end() const { return Iterator(end_); }
243 |
244 | private:
245 | size_t begin_;
246 | size_t end_;
247 | };
248 |
249 | // Seq iterates [0, end)
250 | inline auto Seq(size_t end) { return SeqWithStride<1>(0, end); }
251 |
252 | // Seq iterates [begin, end).
253 | inline auto Seq(size_t begin, size_t end) {
254 | return SeqWithStride<1>(begin, end);
255 | }
256 |
257 | // IndicesOf provides the following equivalence class:
258 | //
259 | // for (size_t j : IndicesOf(container)) { ... }
260 | //   for (size_t j : Seq(container.size())) { ... }
261 | //
262 | template <typename Container>
263 | SeqWithStride<1> IndicesOf(const Container& container) {
264 | return Seq(container.size());
265 | }
266 |
267 | ////////////////////////////////////////////////////////////////////////////////
268 | // Enumerate
269 | ////////////////////////////////////////////////////////////////////////////////
270 |
271 | template <typename T, typename IdxType = size_t,
272 |           typename TIter = decltype(std::begin(std::declval<T>())),
273 |           typename = decltype(std::end(std::declval<T>()))>
274 | constexpr auto Enumerate(T&& iterable) {
275 | class IteratorWithIndex {
276 | public:
277 | IteratorWithIndex(IdxType idx, TIter it) : idx_(idx), it_(it) {}
278 | bool operator!=(const IteratorWithIndex& other) const {
279 | return it_ != other.it_;
280 | }
281 | void operator++() { idx_++, it_++; }
282 | auto operator*() const { return std::tie(idx_, *it_); }
283 |
284 | private:
285 | IdxType idx_;
286 | TIter it_;
287 | };
288 | struct iterator_wrapper {
289 | T iterable;
290 | auto begin() { return IteratorWithIndex{0, std::begin(iterable)}; }
291 | auto end() { return IteratorWithIndex{0, std::end(iterable)}; }
292 | };
293 |   return iterator_wrapper{std::forward<T>(iterable)};
294 | }
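
// Example (illustrative): iterate with an index over a container.
//
//   std::vector<std::string> names = {"a", "b", "c"};
//   for (auto [idx, name] : Enumerate(names)) {
//     // idx runs 0, 1, 2; name references "a", "b", "c".
//   }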
295 |
296 | ////////////////////////////////////////////////////////////////////////////////
297 | // Profiling Helpers
298 | ////////////////////////////////////////////////////////////////////////////////
299 |
300 | #define AR_ENDO_MARKER(...)
301 | #define AR_ENDO_MARKER_TIMEOUT(...)
302 | #define AR_ENDO_TASK(...)
303 | #define AR_ENDO_JOB(...)
304 | #define AR_ENDO_TASK_TIMEOUT(...)
305 | #define AR_ENDO_JOB_TIMEOUT(...)
306 | #define AR_ENDO_SCOPE(...)
307 | #define AR_ENDO_SCOPE_TIMEOUT(...)
308 | #define AR_ENDO_EVENT(...)
309 | #define AR_ENDO_ERROR(...)
310 | #define AR_ENDO_UNITS(...)
311 | #define AR_ENDO_THREAD_NAME(...)
312 | #define AR_ENDO_GROUP(...)
313 |
314 | } // namespace array_record
315 |
316 | #endif // ARRAY_RECORD_CPP_COMMON_H_
317 |
--------------------------------------------------------------------------------
/cpp/layout.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2022 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // Specialized Riegeli file format designed for IO heavy tasks. It works for
16 | // storing large blocks of raw data without a size limit as well as structured
17 | // data like protobufs. In both cases it supports random access and parallel
18 | // reads in a single file.
19 | syntax = "proto2";
20 |
21 | package array_record;
22 |
23 | // Riegeli files are composed of data chunks. Each data chunk contains multiple
24 | // records, and a record can be a serialized proto (with a size limit of 2GB) or
25 | // arbitrary bytes without size limits.
26 | //
27 | // Each Riegeli data chunk is encoded/compressed separately. The chunks are the
28 | // entry points for decoding, which allows us to read the chunks in parallel if
29 | // we know where these entry points are.
30 | //
31 | // We do not know the offsets of the chunks until we have serialized them.
32 | // Hence, a natural way to record these offsets is to store them as a Riegeli
33 | // chunk in the footer. Finally, to tell where the footer starts, we need a
34 | // postscript storing the offset of the footer. The postscript itself is also
35 | // a Riegeli chunk of size 64KB, aligned to a 64KB boundary, so we can locate
36 | // it by subtracting 64KB from the file size. See the illustrated file layout below.
37 | //
38 | // +-----------------+
39 | // | User Data |
40 | // | Riegeli Chunk |
41 | // +-----------------+
42 | // | User Data |
43 | // | Riegeli Chunk |
44 | // +-----------------+
45 | // /\/\/\/\/\/\/\/\/\/\/
46 | // /\/\/\/\/\/\/\/\/\/\/ _+-----------------------+
47 | // +-----------------+ _/ | RiegeliFooterMetadata |
48 | // | Last User Data | __/ +-----------------------+
49 | // | Chunk | _/ | Footer Proto |
50 | // +-----------------_/ +-----------------------+
51 | // | | | Footer Proto |
52 | // | Footer Chunk | +-----------------------+
53 | // | | | Footer Proto |
54 | // +-----------------+---------+-----------------------+
55 | // |RiegeliPostscript| <--- Must Align 64KB and fit in 64KB.
56 | // +-----------------+
57 | //
58 | // The footer is composed of a header (RiegeliFooterMetadata) and an array of
59 | // proto records. We choose an array of proto records instead of repeated fields
60 | // in a single proto to avoid the 2GB proto size limit. We can use proto `enum`
61 | // or `oneof` in the footer metadata to create polymorphic indices. In other
62 | // words, we can develop indices beyond array-like access patterns and extend to
63 | // any data structures worth serializing to disk.
64 |
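// To make the layout concrete, a reader can locate the footer with two seeks.
// The sketch below is C++-like pseudocode mirroring the checks in
// array_record_writer_test.cc; `reader` stands for any Riegeli record reader
// positioned on such a file:
//
//   reader.Seek(reader.Size().value() - (1 << 16));  // last 64KB block
//   RiegeliPostscript postscript;
//   reader.ReadRecord(postscript);                   // replicated 3x
//   reader.Seek(postscript.footer_offset());
//   RiegeliFooterMetadata footer_metadata;
//   reader.ReadRecord(footer_metadata);
//   // ...followed by num_chunks ArrayRecordFooter records.
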
65 | // Footer proto for locating user data chunk.
66 | message ArrayRecordFooter {
67 | optional uint64 chunk_offset = 1;
68 | optional uint64 decoded_data_size = 2;
69 | optional uint64 num_records = 3;
70 | }
71 |
72 | // Metadata/Header in the footer.
73 | message RiegeliFooterMetadata {
74 | // Metadata for ArrayRecordFooter.
75 | message ArrayRecordMetadata {
76 | // Version number of the ArrayRecord format itself.
77 | // This number should rarely change: only for a new layout design that is not
78 | // backward compatible and whose performance and reliability gains justify
79 | // implementing it.
80 | optional uint32 version = 1;
81 | optional uint64 num_chunks = 2;
82 | optional uint64 num_records = 3;
83 |
84 | // Writer options for debugging purposes.
85 | optional string writer_options = 4;
86 | }
87 |
88 | oneof metadata {
89 | // Specifying the metadata be array_record_metadata implies the footer
90 | // protos are ArrayRecordFooter.
91 | ArrayRecordMetadata array_record_metadata = 1;
92 | }
93 |
94 | // Useful field for us to support append in the future.
95 | reserved "previous_metadata";
96 | }
97 |
98 | // A small proto message serialized at the end of the file. This proto would be
99 | // stored three times in the postscript chunk for redundancy. Therefore its
100 | // serialized size, multiplied by the replication count of three, must be
101 | // smaller than the 64KB Riegeli block size.
102 | message RiegeliPostscript {
103 | optional uint64 footer_offset = 1;
104 | optional uint64 magic = 2;
105 |
106 | reserved "footer_encoding";
107 | }
108 |
--------------------------------------------------------------------------------
/cpp/masked_reader.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/masked_reader.h"
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <utility>
21 |
22 | #include "absl/memory/memory.h"
23 | #include "absl/status/status.h"
24 | #include "absl/time/clock.h"
25 | #include "absl/time/time.h"
26 | #include "cpp/common.h"
27 | #include "riegeli/base/object.h"
28 | #include "riegeli/base/status.h"
29 | #include "riegeli/base/types.h"
30 |
31 | namespace array_record {
32 |
33 | using riegeli::Annotate;
34 | using riegeli::Position;
35 | using riegeli::Reader;
36 |
37 | MaskedReader::MaskedReader(std::unique_ptr<riegeli::Reader> src_reader,
38 |                            size_t length)
39 |     : buffer_(std::make_shared<std::string>()) {
40 | auto pos = src_reader->pos();
41 | buffer_->resize(length);
42 | if (!src_reader->Read(length, buffer_->data())) {
43 | Fail(Annotate(src_reader->status(),
44 | "Could not read from the underlying reader"));
45 | return;
46 | }
47 | /*
48 | * limit_pos
49 | * |---------------------------|
50 | * buffer_start buffer_limit
51 | * |................|----------|
52 | */
53 | set_buffer(buffer_->data(), buffer_->size());
54 | set_limit_pos(pos + buffer_->size());
55 | }
56 |
57 | MaskedReader::MaskedReader(std::shared_ptr<std::string> buffer,
58 | Position limit_pos)
59 | : buffer_(buffer) {
60 | /*
61 | * limit_pos
62 | * |---------------------------|
63 | * buffer_start buffer_limit
64 | * |................|----------|
65 | */
66 | set_buffer(buffer_->data(), buffer_->size());
67 | set_limit_pos(limit_pos);
68 | }
69 |
70 | MaskedReader::MaskedReader(MaskedReader &&other) noexcept
71 | : Reader(std::move(other)) {
72 | buffer_ = other.buffer_; // NOLINT(bugprone-use-after-move)
73 | other.Reset(riegeli::kClosed); // NOLINT(bugprone-use-after-move)
74 | }
75 |
76 | MaskedReader &MaskedReader::operator=(MaskedReader &&other) noexcept {
77 | // Move other
78 |   Reader::operator=(static_cast<Reader&&>(other));
79 | // Copy the shared buffer.
80 | buffer_ = other.buffer_;
81 | // Close `other`
82 | other.Reset(riegeli::kClosed);
83 | return *this;
84 | }
85 |
86 | bool MaskedReader::PullSlow(size_t min_length, size_t recommended_length) {
87 | Fail(FailedPreconditionError("Should not pull beyond buffer"));
88 | return false;
89 | }
90 |
91 | bool MaskedReader::SeekSlow(riegeli::Position new_pos) {
92 | Fail(FailedPreconditionError("Should not seek beyond buffer"));
93 | return false;
94 | }
95 |
96 | absl::optional<Position> MaskedReader::SizeImpl() {
97 | return limit_pos();
98 | }
99 |
100 | std::unique_ptr<Reader> MaskedReader::NewReaderImpl(Position initial_pos) {
101 | if (!ok()) {
102 | return nullptr;
103 | }
104 |   std::unique_ptr<MaskedReader> reader =
105 | absl::WrapUnique(new MaskedReader(buffer_, limit_pos()));
106 | if (!reader->Seek(initial_pos)) {
107 | return nullptr;
108 | }
109 | return reader;
110 | }
111 |
112 | } // namespace array_record
113 |
--------------------------------------------------------------------------------
/cpp/masked_reader.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_MASKED_READER_H_
17 | #define ARRAY_RECORD_CPP_MASKED_READER_H_
18 |
19 | #include <memory>
20 | #include <string>
21 |
22 | #include "absl/types/optional.h"
23 | #include "riegeli/base/object.h"
24 | #include "riegeli/base/types.h"
25 | #include "riegeli/bytes/reader.h"
26 |
27 | namespace array_record {
28 |
29 | // `MaskedReader` is a riegeli reader constructed from another reader, with a
30 | // cached buffer over a specified region. Users can further derive new readers
31 | // that share the same buffer.
32 | //
33 | // original file |----------------------------------------------|
34 | //
35 | // masked buffer |..............|---------------|
36 | // |--------------------^ Position is addressed from the
37 | // beginning of the underlying buffer.
38 | //
39 | // The masked buffer and the original file use the same position base address.
40 | // Users cannot seek beyond the buffered region.
41 | //
42 | // This class is useful for reducing the number of PReads. User may create a
43 | // MaskedReader containing multiple chunks, then derive multiple chunk readers
44 | // from this reader sharing the same buffer. Hence, there's only one PRead
45 | // issued for multiple chunks.
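//
// A minimal usage sketch (mirroring masked_reader_test.cc; the data and
// positions are illustrative):
//
//   riegeli::StringReader base_reader("0123456789abcdef");
//   // Window covering bytes [5, 13), addressed by original-file positions.
//   MaskedReader masked(base_reader.NewReader(5), /*length=*/8);
//   // A derived reader sharing the same buffer, starting at position 7.
//   std::unique_ptr<riegeli::Reader> derived = masked.NewReader(7);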
46 | class MaskedReader : public riegeli::Reader {
47 | public:
48 | explicit MaskedReader(riegeli::Closed) : riegeli::Reader(riegeli::kClosed) {}
49 |
50 |   MaskedReader(std::unique_ptr<riegeli::Reader> src_reader, size_t length);
51 |
52 | MaskedReader(MaskedReader &&other) noexcept;
53 | MaskedReader &operator=(MaskedReader &&other) noexcept;
54 |
55 | bool SupportsRandomAccess() override { return true; }
56 | bool SupportsNewReader() override { return true; }
57 |
58 | protected:
59 | bool PullSlow(size_t min_length, size_t recommended_length) override;
60 | bool SeekSlow(riegeli::Position new_pos) override;
61 |
62 |   absl::optional<riegeli::Position> SizeImpl() override;
63 |   std::unique_ptr<riegeli::Reader> NewReaderImpl(
64 | riegeli::Position initial_pos) override;
65 |
66 | private:
67 | // Private constructor that copies itself.
68 |   MaskedReader(std::shared_ptr<std::string> buffer,
69 | riegeli::Position limit_pos);
70 |
71 |   std::shared_ptr<std::string> buffer_;
72 | };
73 |
74 | } // namespace array_record
75 |
76 | #endif // ARRAY_RECORD_CPP_MASKED_READER_H_
77 |
--------------------------------------------------------------------------------
/cpp/masked_reader_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/masked_reader.h"
17 |
18 | #include <string>
19 |
20 | #include "gtest/gtest.h"
21 | #include "riegeli/bytes/string_reader.h"
22 |
23 | namespace array_record {
24 | namespace {
25 |
26 | using riegeli::StringReader;
27 |
28 | TEST(MaskedReaderTest, SanityTest) {
29 | auto data = std::string("0123456789abcdef");
30 | auto base_reader = StringReader(data);
31 | // 56789abc
32 | auto masked_reader1 = MaskedReader(base_reader.NewReader(5), 8);
33 | // Matches where we offset the reader.
34 | EXPECT_EQ(masked_reader1.pos(), 5);
35 | // Matches offset + mask length
36 | EXPECT_EQ(masked_reader1.Size(), 8 + 5);
37 | {
38 | std::string result;
39 | masked_reader1.Read(4, result);
40 | EXPECT_EQ(result, "5678");
41 | EXPECT_EQ(masked_reader1.pos(), 5 + 4);
42 | masked_reader1.Read(4, result);
43 | EXPECT_EQ(result, "9abc");
44 | EXPECT_EQ(masked_reader1.pos(), 5 + 8);
45 | }
46 |
47 | auto masked_reader2 = masked_reader1.NewReader(7);
48 | // Size does not change
49 | EXPECT_EQ(masked_reader2->Size(), 8 + 5);
50 | // pos is the new position we set from NewReader
51 | EXPECT_EQ(masked_reader2->pos(), 7);
52 | {
53 | std::string result;
54 | masked_reader2->Read(4, result);
55 | EXPECT_EQ(result, "789a");
56 | }
57 |
58 |   // Requesting a new reader at an out-of-bounds position does not fail the
59 |   // base reader. It simply returns a nullptr.
60 | EXPECT_EQ(masked_reader1.NewReader(0), nullptr);
61 | EXPECT_EQ(masked_reader1.NewReader(20), nullptr);
62 | EXPECT_TRUE(masked_reader1.ok());
63 |
64 | // Support seek
65 | masked_reader1.Seek(6);
66 | {
67 | std::string result;
68 | masked_reader1.Read(4, result);
69 | EXPECT_EQ(result, "6789");
70 | }
71 |
72 | // Seek beyond buffer is a failure.
73 | EXPECT_FALSE(masked_reader1.Seek(20));
74 | }
75 |
76 | } // namespace
77 |
78 | } // namespace array_record
79 |
--------------------------------------------------------------------------------
/cpp/parallel_for.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_PARALLEL_FOR_H_
17 | #define ARRAY_RECORD_CPP_PARALLEL_FOR_H_
18 |
19 | #include <algorithm>
20 | #include <atomic>
21 | #include <cstddef>
22 | #include <functional>
23 | #include <limits>
24 |
25 | #include "absl/base/optimization.h"
26 | #include "absl/status/status.h"
27 | #include "absl/synchronization/mutex.h"
28 | #include "cpp/common.h"
29 | #include "cpp/thread_pool.h"
30 |
31 | namespace array_record {
32 |
33 | // kDynamicBatchSize - when a batch size isn't specified, ParallelFor defaults
34 | // to dividing the work into (4 * num_threads) batches, enabling decently good
35 | // parallelism, while minimizing coordination overhead.
36 | enum : size_t {
37 |   kDynamicBatchSize = std::numeric_limits<size_t>::max(),
38 | };
39 |
40 | // Options for ParallelFor. The defaults are sufficient for most users.
41 | struct ParallelForOptions {
42 | // It may be desirable to limit parallelism in some cases if e.g.:
43 | // 1. A portion of the loop body requires synchronization and Amdahl's Law
44 | // prevents scaling past a small number of threads.
45 | // 2. You're running on a NUMA system and don't want this loop to execute
46 | // across NUMA nodes.
47 |   size_t max_parallelism = std::numeric_limits<size_t>::max();
48 | };
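
// For example, to cap a loop at 4 threads (hypothetical loop body):
//
//   ParallelForOptions opts;
//   opts.max_parallelism = 4;
//   ParallelFor(Seq(work.size()), &pool,
//               [&](size_t i) { Process(work[i]); }, opts);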
49 |
50 | // ParallelFor - execute a for-loop in parallel, using both the calling thread,
51 | // plus the threads available in a ARThreadPool argument.
52 | //
53 | // Arguments:
54 | //
55 | // * <seq> - the sequence to be processed, eg "Seq(vec.size())".
56 | //
57 | // * <pool> - the threadpool to use (in addition to the main thread). If this
58 | // parameter is nullptr, then ParallelFor will compile down to a vanilla
59 | // single-threaded for-loop. It is permissible for the calling thread to be a
60 | // member of <pool>.
61 | //
62 | // * <function> - the method to call for each value in <seq>.
63 | //
64 | // * <kItersPerBatch> [template param] - the number of calls to <function> that
65 | // each thread will perform before synchronizing access to the loop counter. If
66 | // not provided, the array will be divided into (n_threads * 4) work batches,
67 | // enabling good parallelism in most cases, while minimizing synchronization
68 | // overhead.
69 | //
70 | //
71 | // Example: Compute 1M sqrts in parallel.
72 | //
73 | //   vector<double> sqrts(1000000);
74 | // ParallelFor(Seq(sqrts.size()), &pool, [&](size_t i) {
75 | // sqrts[i] = sqrt(i);
76 | // });
77 | //
78 | //
79 | // Example: Only the evens, using SeqWithStride.
80 | //
81 | // ParallelFor(SeqWithStride<2>(0, sqrts.size()), &pool, [&sqrts](size_t i) {
82 | // sqrts[i] = i;
83 | // });
84 | //
85 | //
86 | // Example: Execute N expensive tasks in parallel, in batches of 1-at-a-time.
87 | //
88 | // ParallelFor<1>(Seq(tasks.size()), &pool, [&](size_t j) {
89 | // DoSomethingExpensive(tasks[j]);
90 | // });
91 | //
92 | template <size_t kItersPerBatch = kDynamicBatchSize, typename SeqT,
93 |           typename Function>
94 | inline void ParallelFor(SeqT seq, ARThreadPool* pool, Function func,
95 | ParallelForOptions opts = ParallelForOptions());
96 |
97 | // ParallelForWithStatus - Similar to ParallelFor, except it can short circuit
98 | // if an error occurred.
99 | //
100 | // Arguments:
101 | // * <seq> - the sequence to be processed, e.g. "Seq(vec.size())".
102 | //
103 | // * <pool> - the threadpool to use (in addition to the main thread). If this
104 | // parameter is nullptr, then ParallelFor will compile down to a vanilla
105 | // single-threaded for-loop. It is permissible for the calling thread to be a
106 | // member of <pool>.
107 | //
108 | // * <func> - the method to call for each value in <seq> with the type
109 | // interface of std::function<absl::Status(size_t idx)>.
110 | //
111 | // * <kItersPerBatch> [template param] - the number of calls to <func> that
112 | // each thread will perform before synchronizing access to the loop counter. If
113 | // not provided, the array will be divided into (n_threads * 4) work batches,
114 | // enabling good parallelism in most cases, while minimizing synchronization
115 | // overhead.
116 | //
117 | // Example:
118 | //
119 | // auto status = ParallelForWithStatus<1>(
120 | // Seq(tasks.size()), &pool, [&](size_t idx) -> absl::Status {
121 | // RETURN_IF_ERROR(RunTask(tasks[idx]));
122 | // return absl::OkStatus();
123 | // });
124 | //
125 | template <size_t kItersPerBatch = kDynamicBatchSize, typename SeqT,
126 |           typename Function>
127 | inline absl::Status ParallelForWithStatus(
128 | SeqT seq, ARThreadPool* pool, Function Func,
129 | ParallelForOptions opts = ParallelForOptions()) {
130 | absl::Status finite_check_status = absl::OkStatus();
131 |
132 | std::atomic_bool is_ok_status{true};
133 | absl::Mutex mutex;
134 | ParallelFor(
135 | seq, pool,
136 | [&](size_t idx) {
137 | if (!is_ok_status.load(std::memory_order_relaxed)) {
138 | return;
139 | }
140 | absl::Status status = Func(idx);
141 | if (!status.ok()) {
142 | absl::MutexLock lock(&mutex);
143 | finite_check_status = status;
144 | is_ok_status.store(false, std::memory_order_relaxed);
145 | }
146 | },
147 | opts);
148 | return finite_check_status;
149 | }
150 |
151 | ////////////////////////////////////////////////////////////////////////////////
152 | // IMPLEMENTATION DETAILS
153 | ////////////////////////////////////////////////////////////////////////////////
154 |
155 | namespace parallel_for_internal {
156 |
157 | // ParallelForClosure - a single heap-allocated object that holds the loop's
158 | // state. The object will delete itself when the final task completes.
159 | template <size_t kItersPerBatch, typename SeqT, typename Function>
160 | class ParallelForClosure {
161 | public:
162 | static constexpr bool kIsDynamicBatch = (kItersPerBatch == kDynamicBatchSize);
163 | ParallelForClosure(SeqT seq, Function func)
164 | : func_(func),
165 | index_(*seq.begin()),
166 | range_end_(*seq.end()),
167 | reference_count_(1) {}
168 |
169 | inline void RunParallel(ARThreadPool* pool, size_t desired_threads) {
170 | // Don't push more tasks to the pool than we have work for. Also, if
171 | // parallelism is limited by desired_threads rather than by the pool size,
172 | // subtract 1 from the number of tasks pushed to account for the main thread.
173 | size_t n_threads =
174 | std::min<size_t>(desired_threads - 1, pool->NumThreads());
175 |
176 | // Handle dynamic batch size.
177 | if (kIsDynamicBatch) {
178 | batch_size_ =
179 | SeqT::Stride() * std::max(1ul, desired_threads / 4 / n_threads);
180 | }
181 |
182 | reference_count_ += n_threads;
183 | while (n_threads--) {
184 | pool->Schedule([this]() { Run(); });
185 | }
186 |
187 | // Do work on the main thread. Once this returns, we are guaranteed that all
188 | // batches have been assigned to some thread.
189 | DoWork();
190 |
191 | // Then wait for all worker threads to exit the core loop. Thus, once the
192 | // main thread is able to take a WriterLock, we are guaranteed that all
193 | // batches have finished, allowing the main thread to move on.
194 | //
195 | // The main thread does *NOT* wait for ARThreadPool tasks that haven't yet
196 | // entered the core loop. This is important for handling scenarios where
197 | // the ARThreadPool falls significantly behind and hasn't started some of
198 | // the tasks assigned to it. Once assigned, those tasks will quickly realize
199 | // that there is no work left, and the final task to schedule will delete
200 | // this heap-allocated object.
201 | //
202 | termination_mutex_.WriterLock();
203 | termination_mutex_.WriterUnlock();
204 |
205 | // Drop main thread's reference.
206 | if (--reference_count_ == 0) delete this;
207 | }
208 |
209 | void Run() {
210 | // Do work on a child thread. Before starting any work, each child thread
211 | // takes a reader lock, preventing the main thread from finishing while
212 | // any child threads are still executing in the core loop.
213 | termination_mutex_.ReaderLock();
214 | DoWork();
215 | termination_mutex_.ReaderUnlock();
216 |
217 | // Drop child thread's reference.
218 | if (--reference_count_ == 0) delete this;
219 | }
220 |
221 | // DoWork - the "core loop", executed in parallel on N threads.
222 | inline void DoWork() {
223 | // Performance Note: Copying constant values to the stack allows the
224 | // compiler to know that they are actually constant and can be assigned to
225 | // registers (the 'const' keyword is insufficient).
226 | const size_t range_end = range_end_;
227 |
228 | // Performance Note: when batch size is not dynamic, the compiler will treat
229 | // it as a constant that can be directly inlined into the code w/o consuming
230 | // a register.
231 | constexpr size_t kStaticBatchSize = SeqT::Stride() * kItersPerBatch;
232 | const size_t batch_size = kIsDynamicBatch ? batch_size_ : kStaticBatchSize;
233 |
234 | // The core loop:
235 | for (;;) {
236 | // The std::atomic index_ coordinates sequential batch assignment.
237 | const size_t batch_begin = index_.fetch_add(batch_size);
238 | // Once assigned, batches execute w/o further coordination.
239 | const size_t batch_end = std::min(batch_begin + batch_size, range_end);
240 | if (ABSL_PREDICT_FALSE(batch_begin >= range_end)) break;
241 | for (size_t idx : SeqWithStride<SeqT::Stride()>(batch_begin, batch_end)) {
242 | func_(idx);
243 | }
244 | }
245 | }
246 |
247 | private:
248 | Function func_;
249 |
250 | // The index_ is used by worker threads to coordinate sequential batch
251 | // assignment. This is the only coordination mechanism inside the core loop.
252 | std::atomic<size_t> index_;
253 |
254 | // The iteration stops at range_end_.
255 | const size_t range_end_;
256 |
257 | // The termination_mutex_ coordinates termination of the for-loop, allowing
258 | // the main thread to block until all child threads have exited the core loop.
259 | absl::Mutex termination_mutex_;
260 |
261 | // For smaller arrays, it's possible for the work to finish before some of
262 | // the ARThreadPool tasks have started running, and in some extreme cases, it
263 | // might take entire milliseconds before these tasks begin running. The main
264 | // thread will continue doing other work, and the last task to schedule and
265 | // terminate will delete this heap allocated object.
266 | std::atomic reference_count_;
267 |
268 | // The batch_size_ member variable is only used when the batch size is
269 | // dynamic (i.e., kItersPerBatch == kDynamicBatchSize).
270 | size_t batch_size_ = kItersPerBatch;
271 | };
272 |
273 | } // namespace parallel_for_internal
274 |
275 | template <size_t kItersPerBatch, typename SeqT, typename Function>
276 | inline void ParallelFor(SeqT seq, ARThreadPool* pool, Function func,
277 | ParallelForOptions opts) {
278 | // Figure out how many batches of work we have.
279 | constexpr size_t kMinItersPerBatch =
280 | kItersPerBatch == kDynamicBatchSize ? 1 : kItersPerBatch;
281 | const size_t desired_threads = std::min(
282 | opts.max_parallelism, DivRoundUp(*seq.end() - *seq.begin(),
283 | SeqT::Stride() * kMinItersPerBatch));
284 |
285 | // Unfortunately, the thread pool exposes no interface to monitor queue
286 | // fullness, so fall back to a serialized vanilla for-loop to handle any of:
287 | //
288 | // * No ThreadPool provided.
289 | //
290 | // * ThreadPool has fallen very far behind.
291 | //
292 | // * Small arrays w/ only 1 batch of work to do.
293 | //
294 | // Note that the compiler will inline this logic, making ParallelFor
295 | // equivalent to a traditional C++ for-loop for the cases above.
296 | if (pool == nullptr ||
297 |     desired_threads <= 1) {
299 | for (size_t idx : seq) {
300 | func(idx);
301 | }
302 | return;
303 | }
304 |
305 | // Otherwise, fire up the threadpool. Note that the shared closure object is
306 | // heap-allocated, allowing this method to finish even if some tasks haven't
307 | // started running yet. The object will be deleted by the last finishing task,
308 | // or possibly by this thread, whichever is last to terminate.
309 | using parallel_for_internal::ParallelForClosure;
310 | auto closure =
311 | new ParallelForClosure<kItersPerBatch, SeqT, Function>(seq, func);
312 | closure->RunParallel(pool, desired_threads);
313 | }
314 |
315 | } // namespace array_record
316 |
317 | #endif // ARRAY_RECORD_CPP_PARALLEL_FOR_H_
318 |
--------------------------------------------------------------------------------
/cpp/parallel_for_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | // Tests for parallel_for.h.
17 | #include "cpp/parallel_for.h"
18 |
19 | #include <math.h>
20 | #include <stddef.h>
21 | #include <stdint.h>
22 | #include <atomic>
23 | #include <functional>
24 | #include <memory>
25 | #include <utility>
26 | #include <vector>
27 |
28 | #include "gmock/gmock.h"
29 | #include "gtest/gtest.h"
30 | #include "absl/functional/function_ref.h"
31 | #include "absl/status/status.h"
32 | #include "cpp/common.h"
33 | #include "cpp/thread_pool.h"
34 |
35 | namespace array_record {
36 |
37 | class ParallelForTest : public testing::Test {
38 | protected:
39 | void SetUp() final { pool_ = ArrayRecordGlobalPool(); }
40 |
41 | public:
42 | static constexpr int32_t kNumElements = 1000000;
43 | ARThreadPool* pool_;
44 | };
45 |
46 | TEST_F(ParallelForTest, AtomicCounterTest) {
47 | std::vector<int32_t> items(kNumElements, 0);
48 | ParallelFor(Seq(kNumElements), pool_, [&](size_t j) { items[j] = 1; });
49 | int32_t num_accessed = 0;
50 | for (auto item_was_accessed : items) {
51 | if (item_was_accessed) {
52 | num_accessed++;
53 | }
54 | }
55 | EXPECT_EQ(num_accessed, kNumElements);
56 | }
57 |
58 | TEST_F(ParallelForTest, ImplicitStepImplicitBlock) {
59 | std::vector<double> result(kNumElements);
60 | auto l = [&result](size_t j) { result[j] += sqrt(j); };
61 |
62 | ParallelFor(Seq(kNumElements), pool_, l);
63 | for (size_t j : Seq(kNumElements)) {
64 | EXPECT_EQ(sqrt(j), result[j]);
65 | }
66 |
67 | result.clear();
68 | result.resize(kNumElements);
69 | ParallelFor(Seq(kNumElements), nullptr, l);
70 | for (size_t j : Seq(kNumElements)) {
71 | EXPECT_EQ(sqrt(j), result[j]);
72 | }
73 | }
74 |
75 | TEST_F(ParallelForTest, ImplicitStepExplicitBlock) {
76 | std::vector<double> result(kNumElements);
77 | auto l = [&result](size_t j) { result[j] += sqrt(j); };
78 | ParallelFor<10>(Seq(kNumElements), pool_, l);
79 |
80 | for (size_t j : Seq(kNumElements)) {
81 | EXPECT_EQ(sqrt(j), result[j]);
82 | }
83 |
84 | result.clear();
85 | result.resize(kNumElements);
86 | ParallelFor<10>(Seq(kNumElements), nullptr, l);
87 |
88 | for (size_t j : Seq(kNumElements)) {
89 | EXPECT_EQ(sqrt(j), result[j]);
90 | }
91 | }
92 |
93 | TEST_F(ParallelForTest, ExplicitStepExplicitBlock) {
94 | std::vector<double> result(kNumElements);
95 | auto l = [&result](size_t j) { result[j] += sqrt(j); };
96 | ParallelFor<10>(SeqWithStride<2>(kNumElements), pool_, l);
97 |
98 | for (size_t j : Seq(kNumElements)) {
99 | // We only did the even numbered elements, so the odd ones should be zero.
100 | if (j & 1) {
101 | EXPECT_EQ(result[j], 0.0);
102 | } else {
103 | EXPECT_EQ(sqrt(j), result[j]);
104 | }
105 | }
106 |
107 | result.clear();
108 | result.resize(kNumElements);
109 | ParallelFor<10>(SeqWithStride<2>(kNumElements), nullptr, l);
110 |
111 | for (size_t j : Seq(kNumElements)) {
112 | // We only did the even numbered elements, so the odd ones should be zero.
113 | if (j & 1) {
114 | EXPECT_EQ(result[j], 0.0);
115 | } else {
116 | EXPECT_EQ(sqrt(j), result[j]);
117 | }
118 | }
119 | }
120 |
121 | TEST_F(ParallelForTest, ExplicitStepImplicitBlock) {
122 | std::vector<double> result(kNumElements);
123 | auto l = [&result](size_t j) { result[j] += sqrt(j); };
124 | ParallelFor(SeqWithStride<2>(kNumElements), pool_, l);
125 |
126 | for (size_t j : Seq(kNumElements)) {
127 | // We only did the even numbered elements, so the odd ones should be zero.
128 | if (j & 1) {
129 | EXPECT_EQ(result[j], 0.0);
130 | } else {
131 | EXPECT_EQ(sqrt(j), result[j]);
132 | }
133 | }
134 |
135 | result.clear();
136 | result.resize(kNumElements);
137 | ParallelFor(SeqWithStride<2>(kNumElements), nullptr, l);
138 |
139 | for (size_t j : Seq(kNumElements)) {
140 | // We only did the even numbered elements, so the odd ones should be zero.
141 | if (j & 1) {
142 | EXPECT_EQ(result[j], 0.0);
143 | } else {
144 | EXPECT_EQ(sqrt(j), result[j]);
145 | }
146 | }
147 | }
148 |
149 | TEST_F(ParallelForTest, ExampleCompiles) {
150 | // Once the C++ style guide approves lambdas, usage of this library will become
151 | // very clean, as illustrated below. Compute the square root of every number from 0 to
152 | // 1000000 in parallel.
153 | std::vector<double> sqrts(1000000);
154 | auto pool = ArrayRecordGlobalPool();
155 |
156 | ParallelFor(Seq(sqrts.size()), pool,
157 | [&sqrts](size_t j) { sqrts[j] = sqrt(j); });
158 |
159 | // Only compute the square roots of even numbers by using an explicit step.
160 | ParallelFor(SeqWithStride<2>(sqrts.size()), pool,
161 | [&sqrts](size_t j) { sqrts[j] = j; });
162 |
163 | // The block_size parameter can be adjusted to control the granularity of
164 | // parallelism. This parameter represents the number of iterations of the
165 | // loop that will be done in a single thread before communicating with any
166 | // other threads. Smaller block sizes lead to better load balancing between
167 | // threads. Larger block sizes lead to less communication overhead and less
168 | // risk of false sharing (http://en.wikipedia.org/wiki/False_sharing)
169 | // when writing to adjacent array elements from different threads based
170 | // on the loop index. The default block size creates approximately 4
171 | // blocks per thread.
172 | //
173 | // Use an explicit block size of 10.
174 | ParallelFor<10>(SeqWithStride<2>(sqrts.size()), pool,
175 | [&sqrts](size_t j) { sqrts[j] = j; });
176 | }
177 |
178 | TEST_F(ParallelForTest, ParallelForWithStatusTest) {
179 | std::atomic_int counter = 0;
180 | auto status =
181 | ParallelForWithStatus<1>(Seq(kNumElements), pool_, [&](size_t i) {
182 | counter.fetch_add(1, std::memory_order_release);
183 | return absl::OkStatus();
184 | });
185 | EXPECT_TRUE(status.ok());
186 | EXPECT_EQ(counter.load(std::memory_order_acquire), kNumElements);
187 | }
188 |
189 | TEST_F(ParallelForTest, ParallelForWithStatusTestShortCircuit) {
190 | std::atomic_int counter = 0;
191 | auto status =
192 | ParallelForWithStatus<1>(Seq(kNumElements), pool_, [&](size_t i) {
193 | counter.fetch_add(1, std::memory_order_release);
194 | return absl::UnknownError("Intended error");
195 | });
196 | EXPECT_EQ(status.code(), absl::StatusCode::kUnknown);
197 | EXPECT_LT(counter.load(std::memory_order_acquire), kNumElements);
198 | }
199 |
200 | } // namespace array_record
201 |
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/sequenced_chunk_writer.h"
17 |
18 | #include <chrono>  // NOLINT(build/c++11)
19 | #include <cstdint>
20 | #include <future>  // NOLINT(build/c++11)
21 | #include <utility>
22 |
23 | #include "absl/status/status.h"
24 | #include "absl/status/statusor.h"
25 | #include "absl/strings/str_format.h"
26 | #include "absl/synchronization/mutex.h"
27 | #include "riegeli/base/status.h"
28 | #include "riegeli/base/types.h"
29 | #include "riegeli/chunk_encoding/chunk.h"
30 | #include "riegeli/chunk_encoding/constants.h"
31 | #include "riegeli/records/chunk_writer.h"
32 |
33 | namespace array_record {
34 |
35 | bool SequencedChunkWriterBase::CommitFutureChunk(
36 | std::future<absl::StatusOr<riegeli::Chunk>>&& future_chunk) {
37 | absl::MutexLock l(&mu_);
38 | if (!ok()) {
39 | return false;
40 | }
41 | queue_.push(std::move(future_chunk));
42 | return true;
43 | }
44 |
45 | bool SequencedChunkWriterBase::SubmitFutureChunks(bool block) {
46 | // We need to use TryLock to prevent deadlock.
47 | //
48 | // std::future::get() blocks if the result wasn't ready.
49 | // Hence the following scenario triggers a deadlock.
50 | // T1:
51 | // SubmitFutureChunks(true)
52 | // mu_ holds
53 | // Blocks on queue_.front().get();
54 | // T2:
55 | // In charge to fulfill the future of queue_.front() on its exit.
56 | // SubmitFutureChunks(false)
57 | // Blocks on mu_ if we used mu_.Lock() instead of mu_.TryLock()
58 | //
59 | // NOTE: Even if ok() is false, the below loop will drain queue_, either
60 | // completely if block is true, or until a non-ready future is at the front of
61 | // the queue in the non-blocking case. If ok() is false, the front element is
62 | // popped from the queue and discarded.
63 |
64 | if (block) {
65 | // When blocking, we block both on mutex acquisition and on future
66 | // completion.
67 | absl::MutexLock lock(&mu_);
68 | riegeli::ChunkWriter* writer = get_writer();
69 | while (!queue_.empty()) {
70 | TrySubmitFirstFutureChunk(writer);
71 | }
72 | return ok();
73 | } else if (mu_.TryLock()) {
74 | // When non-blocking, we only proceed if we can lock the mutex without
75 | // blocking, and we only process those futures that are ready. We need
76 | // to unlock the mutex manually in this case, and take care to call ok()
77 | // under the lock.
78 | riegeli::ChunkWriter* writer = get_writer();
79 | while (!queue_.empty() &&
80 | queue_.front().wait_for(std::chrono::microseconds::zero()) ==
81 | std::future_status::ready) {
82 | TrySubmitFirstFutureChunk(writer);
83 | }
84 | bool result = ok();
85 | mu_.Unlock();
86 | return result;
87 | } else {
88 | return true;
89 | }
90 | }
91 |
92 | void SequencedChunkWriterBase::TrySubmitFirstFutureChunk(
93 | riegeli::ChunkWriter* chunk_writer) {
94 | auto status_or_chunk = queue_.front().get();
95 | queue_.pop();
96 |
97 | if (!ok() || !chunk_writer->ok()) {
98 | // Note (see above): the front of the queue is popped even if we discard it
99 | // now.
100 | return;
101 | }
102 | // Set self unhealthy for bad chunks.
103 | if (!status_or_chunk.ok()) {
104 | Fail(riegeli::Annotate(
105 | status_or_chunk.status(),
106 | absl::StrFormat("Could not submit chunk: %d", submitted_chunks_)));
107 | return;
108 | }
109 | riegeli::Chunk chunk = std::move(status_or_chunk.value());
110 | uint64_t chunk_offset = chunk_writer->pos();
111 | uint64_t decoded_data_size = chunk.header.decoded_data_size();
112 | uint64_t num_records = chunk.header.num_records();
113 |
114 | if (!chunk_writer->WriteChunk(std::move(chunk))) {
115 | Fail(riegeli::Annotate(
116 | chunk_writer->status(),
117 | absl::StrFormat("Could not submit chunk: %d", submitted_chunks_)));
118 | return;
119 | }
120 | if (pad_to_block_boundary_) {
121 | if (!chunk_writer->PadToBlockBoundary()) {
122 | Fail(riegeli::Annotate(
123 | chunk_writer->status(),
124 | absl::StrFormat("Could not pad boundary for chunk: %d",
125 | submitted_chunks_)));
126 | return;
127 | }
128 | }
129 | if (!chunk_writer->Flush(riegeli::FlushType::kFromObject)) {
130 | Fail(riegeli::Annotate(
131 | chunk_writer->status(),
132 | absl::StrFormat("Could not flush chunk: %d", submitted_chunks_)));
133 | return;
134 | }
135 | if (callback_) {
136 | (*callback_)(submitted_chunks_, chunk_offset, decoded_data_size,
137 | num_records);
138 | }
139 | submitted_chunks_++;
140 | }
141 |
142 | void SequencedChunkWriterBase::Initialize() {
143 | auto* chunk_writer = get_writer();
144 | riegeli::Chunk chunk;
145 | chunk.header = riegeli::ChunkHeader(chunk.data,
146 | riegeli::ChunkType::kFileSignature, 0, 0);
147 | if (!chunk_writer->WriteChunk(chunk)) {
148 | Fail(riegeli::Annotate(chunk_writer->status(),
149 | "Failed to create the file header"));
150 | }
151 | if (!chunk_writer->Flush(riegeli::FlushType::kFromObject)) {
152 | Fail(riegeli::Annotate(chunk_writer->status(), "Could not flush"));
153 | }
154 | }
155 |
156 | void SequencedChunkWriterBase::Done() {
157 | if (!SubmitFutureChunks(true)) {
158 | Fail(absl::InternalError("Unable to submit pending chunks"));
159 | return;
160 | }
161 | auto* chunk_writer = get_writer();
162 | if (!chunk_writer->Close()) {
163 | Fail(riegeli::Annotate(chunk_writer->status(),
164 | "Failed to close chunk_writer"));
165 | }
166 | }
167 |
168 | } // namespace array_record
169 |
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | // Low-level API for building off-memory data structures with riegeli.
17 | //
18 | // `SequencedChunkWriter` writes chunks of records to an abstracted destination.
19 | // This class abstracts out the generic chunk writing logic from concrete logic
20 | // that builds up the data structures for future access. This class is
21 | // thread-safe and allows users to encode each chunk concurrently while
22 | // maintaining the sequence order of the chunks.
23 |
24 | #ifndef ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_
25 | #define ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_
26 |
27 | #include <cstdint>
28 | #include <future>  // NOLINT(build/c++11)
29 | #include <queue>
30 | #include <utility>
31 |
32 | #include "absl/base/thread_annotations.h"
33 | #include "absl/status/status.h"
34 | #include "absl/status/statusor.h"
35 | #include "absl/synchronization/mutex.h"
36 | #include "cpp/common.h"
37 | #include "riegeli/base/initializer.h"
38 | #include "riegeli/base/object.h"
39 | #include "riegeli/bytes/writer.h"
40 | #include "riegeli/chunk_encoding/chunk.h"
41 | #include "riegeli/records/chunk_writer.h"
42 |
43 | namespace array_record {
44 |
45 | // Template parameter independent part of `SequencedChunkWriter`.
46 | class SequencedChunkWriterBase : public riegeli::Object {
47 | SequencedChunkWriterBase(const SequencedChunkWriterBase&) = delete;
48 | SequencedChunkWriterBase& operator=(const SequencedChunkWriterBase&) = delete;
49 | SequencedChunkWriterBase(SequencedChunkWriterBase&&) = delete;
50 | SequencedChunkWriterBase& operator=(SequencedChunkWriterBase&&) = delete;
51 |
52 | public:
53 | // The `SequencedChunkWriter` invokes `SubmitChunkCallback` for each
54 | // successfully written chunk. In other words, the invocation only happens when
55 | // no error occurs (in the SequencedChunkWriter internal state, chunk
56 | // correctness, underlying writer state, etc.) The callback consists of four
57 | // arguments:
58 | //
59 | // chunk_seq: sequence number of the chunk in the file. Indexed from 0.
60 | // chunk_offset: byte offset of the chunk in the file. A reader can seek to
61 | // this offset and decode the chunk without any other information.
62 | // decoded_data_size: byte size of the decoded data. Users may serialize this
63 | // field for readers to allocate memory for the decoded data.
64 | // num_records: number of records in the chunk.
65 | class SubmitChunkCallback {
66 | public:
67 | virtual ~SubmitChunkCallback() {}
68 | virtual void operator()(uint64_t chunk_seq, uint64_t chunk_offset,
69 | uint64_t decoded_data_size,
70 | uint64_t num_records) = 0;
71 | };
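72 | //
73 | // A minimal callback sketch (the class name OffsetCollector is illustrative;
74 | // it mirrors the callback used in sequenced_chunk_writer_test.cc) that
75 | // records chunk offsets for building a lookup table:
76 | //
77 | //   class OffsetCollector : public SequencedChunkWriterBase::SubmitChunkCallback {
78 | //    public:
79 | //     void operator()(uint64_t chunk_seq, uint64_t chunk_offset,
80 | //                     uint64_t decoded_data_size, uint64_t num_records) override {
81 | //       offsets_.push_back(chunk_offset);
82 | //     }
83 | //    private:
84 | //     std::vector<uint64_t> offsets_;
85 | //   };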
72 |
73 | // Commits a future chunk to the `SequencedChunkWriter` before the chunk is
74 | // materialized. Users can encode the chunk in a separate thread at the cost of
75 | // higher transient memory usage. `SequencedChunkWriter` serializes the chunks
76 | // in the order of the calls to this function.
77 | //
78 | // Example 1: packaged_task
79 | //
80 | // std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
81 | //     []() -> absl::StatusOr<riegeli::Chunk> {
82 | // ... returns a riegeli::Chunk on success.
83 | // });
84 | // std::future<absl::StatusOr<riegeli::Chunk>> task_future =
85 | //     encoding_task.get_future();
86 | // sequenced_chunk_writer->CommitFutureChunk(std::move(task_future));
87 | //
88 | // // Computes the encoding task in a thread pool.
89 | // pool->Schedule(std::move(encoding_task));
90 | //
91 | // Example 2: promise and future
92 | //
93 | // std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
94 | // RET_CHECK(sequenced_chunk_writer->CommitFutureChunk(
95 | // chunk_promise.get_future())) << sequenced_chunk_writer->status();
96 | // pool->Schedule([chunk_promise = std::move(chunk_promise)]() mutable {
97 | // // computes chunk
98 | // chunk_promise.set_value(status_or_chunk);
99 | // });
100 | //
101 | // Although `SequencedChunkWriter` is thread-safe, this method should be
102 | // invoked from a single thread because it doesn't make sense to submit
103 | // future chunks out of order.
104 | bool CommitFutureChunk(
105 | std::future<absl::StatusOr<riegeli::Chunk>>&& future_chunk);
106 |
107 | // Extracts the future chunks and submits them to the underlying destination.
108 | // This operation may block if the argument `block` is true. This method is
109 | // thread-safe, and we recommend invoking it with `block=false` in each worker
110 | // thread to reduce transient memory usage.
111 | //
112 | // If ok() is false before or during this operation, queue elements continue
113 | // to be extracted, but are immediately discarded.
114 | //
115 | // Example 1: single thread usage
116 | //
117 | // std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
118 | // RET_CHECK(sequenced_chunk_writer->CommitFutureChunk(
119 | // chunk_promise.get_future())) << sequenced_chunk_writer->status();
120 | // chunk_promise.set_value(ComputesChunk());
121 | // RET_CHECK(writer->SubmitFutureChunks(true)) << writer->status();
122 | //
123 | // Example 2: concurrent access
124 | //
125 | // riegeli::SharedPtr<SequencedChunkWriter<riegeli::StringWriter<std::string*>>> writer(riegeli::Maker(riegeli::Maker(&dest)));
126 | // std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
127 | // writer->CommitFutureChunk(chunk_promise.get_future());
128 | // pool->Schedule([writer,
128 | // chunk_promise = std::move(chunk_promise)]() mutable {
129 | // chunk_promise.set_value(status_or_chunk);
130 | //     // Should not block, otherwise we would enter a deadlock!
131 | // writer->SubmitFutureChunks(false);
132 | // });
133 | // // Blocking the main thread is fine.
134 | // RET_CHECK(writer->SubmitFutureChunks(true)) << writer->status();
135 | //
136 | bool SubmitFutureChunks(bool block = false);
137 |
138 | // Pads to 64KB boundary for future chunk submission. (Default false).
139 | void set_pad_to_block_boundary(bool pad_to_block_boundary) {
140 | absl::MutexLock l(&mu_);
141 | pad_to_block_boundary_ = pad_to_block_boundary;
142 | }
143 | bool pad_to_block_boundary() {
144 | absl::MutexLock l(&mu_);
145 | return pad_to_block_boundary_;
146 | }
147 |
148 | // Sets up a callback for each committed chunk. See SubmitChunkCallback
149 | // comments for details.
150 | void set_submit_chunk_callback(SubmitChunkCallback* callback) {
151 | absl::MutexLock l(&mu_);
152 | callback_ = callback;
153 | }
154 |
155 | // Guards access to the status.
156 | absl::Status status() const {
157 | absl::ReaderMutexLock l(&mu_);
158 | return riegeli::Object::status();
159 | }
160 |
161 | protected:
162 | SequencedChunkWriterBase() {}
163 | virtual riegeli::ChunkWriter* get_writer() = 0;
164 |
165 | // Initializes and validates the underlying writer states.
166 | void Initialize();
167 |
168 | // Callback for riegeli::Object::Close.
169 | void Done() override;
170 |
171 | private:
172 | // Attempts to submit the first chunk from the queue. Expects that the lock is
173 | // already held. Even if ok() is false on entry already, the queue element is
174 | // removed (and discarded).
175 | void TrySubmitFirstFutureChunk(riegeli::ChunkWriter* chunk_writer)
176 | ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
177 |
178 | mutable absl::Mutex mu_;
179 | bool pad_to_block_boundary_ ABSL_GUARDED_BY(mu_) = false;
180 | SubmitChunkCallback* callback_ ABSL_GUARDED_BY(mu_) = nullptr;
181 |
182 | // Records the sequence number of submitted chunks.
183 | uint64_t submitted_chunks_ ABSL_GUARDED_BY(mu_) = 0;
184 |
185 | // Queue for storing the future chunks.
186 | std::queue<std::future<absl::StatusOr<riegeli::Chunk>>> queue_
187 | ABSL_GUARDED_BY(mu_);
188 | };
189 |
190 | // A `SequencedChunkWriter` writes chunks (a blob of multiple and possibly
191 | // compressed records) rather than individual records to an abstracted
192 | // destination. `SequencedChunkWriter` allows users to encode each chunk
193 | // concurrently while keeping the chunk sequence order as the input order.
194 | //
195 | // Users can also supply a `SubmitChunkCallback` to collect chunk sequence
196 | // numbers, offsets in the file, decoded data size, and the number of records in
197 | // each chunk. Users may use the callback information to produce a lookup table
198 | // in the footer for an efficient reader to decode multiple chunks in parallel.
199 | //
200 | // Example usage:
201 | //
202 | // // Step 1: open the writer with file backend.
203 | // File* file = file::OpenOrDie(...);
204 | // riegeli::SharedPtr<SequencedChunkWriter<riegeli::FdWriter<>>> writer(riegeli::Maker(
205 | // riegeli::Maker(filename_or_file)));
206 | //
207 | // // Step 2: create a chunk encoding task.
208 | // std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
209 | //     []() -> absl::StatusOr<riegeli::Chunk> {
210 | // ... returns a riegeli::Chunk on success.
211 | // });
212 | //
213 | // // Step 3: book a slot for writing the encoded chunk.
214 | // RET_CHECK(writer->CommitFutureChunk(
215 | // encoding_task.get_future())) << writer->status();
216 | //
217 | // // Step 4: Computes the encoding task in a thread pool.
218 | // pool->Schedule([=,encoding_task=std::move(encoding_task)]() mutable {
219 | // encoding_task(); // std::promise fulfilled.
220 | //     // riegeli::SharedPtr prevents the writer from going out of scope, so
221 | //     // it is safe to invoke the method here.
222 | // writer->SubmitFutureChunks(false);
223 | // });
224 | //
225 | // // Repeats step 2 to 4.
226 | //
227 | // // Finally, close the writer.
228 | // RET_CHECK(writer->Close()) << writer->status();
229 | //
230 | //
231 | // It is necessary to call `Close()` at the end of a successful writing
232 | // session. It is not needed to call `Close()` on early returns, assuming that
233 | // contents of the destination do not matter after all, e.g. because a failure
234 | // is being reported instead; the destructor releases resources in any case.
236 | //
237 | // `SequencedChunkWriter` inherits riegeli::Object which provides useful
238 | // abstractions for state management of IO-like operations. Instead of the
239 | // common absl::Status/StatusOr for each method, riegeli::Object's error
240 | // handling mechanism uses bool returns plus separate `status()`, `ok()`, and
241 | // `is_open()` accessors for users to handle different types of failure states.
242 | //
243 | //
244 | // `SequencedChunkWriter` uses a templated backend abstraction. To serialize
245 | // the output to a string, users simply write:
246 | //
247 | // std::string dest;
248 | // SequencedChunkWriter<riegeli::StringWriter<>> writes_to_string(
249 | // riegeli::Maker(&dest));
250 | //
251 | // Similarly, users can write the output to a cord or to a file.
252 | //
253 | // absl::Cord cord;
254 | // SequencedChunkWriter<riegeli::CordWriter<>> writes_to_cord(
255 | // riegeli::Maker(&cord));
256 | //
257 | // SequencedChunkWriter<riegeli::FdWriter<>> writes_to_file(
258 | // riegeli::Maker(filename_or_file));
259 | //
260 | // Users may also use riegeli::SharedPtr or std::make_unique to construct the
261 | // instance, as shown in the previous example.
262 | template <typename Dest>
263 | class SequencedChunkWriter : public SequencedChunkWriterBase {
264 | public:
265 | DECLARE_IMMOBILE_CLASS(SequencedChunkWriter);
266 |
267 | // Will write to the `Writer` provided by `dest`.
268 | explicit SequencedChunkWriter(riegeli::Initializer<Dest> dest)
269 | : dest_(std::move(dest)) {
270 | Initialize();
271 | }
272 |
273 | protected:
274 | riegeli::ChunkWriter* get_writer() final { return &dest_; }
275 |
276 | private:
277 | riegeli::DefaultChunkWriter<Dest> dest_;
278 | };
279 |
280 | template <typename Dest>
281 | explicit SequencedChunkWriter(Dest&& dest)
282 | -> SequencedChunkWriter<riegeli::TargetT<Dest>>;
283 |
284 | } // namespace array_record
285 |
286 | #endif // ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_
287 |
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/sequenced_chunk_writer.h"
17 |
18 | #include <cstdint>
19 | #include <future>  // NOLINT(build/c++11)
20 | #include <memory>
21 | #include <string>
22 | #include <utility>
23 | #include <vector>
24 |
25 | #include "gtest/gtest.h"
26 | #include "absl/status/status.h"
27 | #include "absl/status/statusor.h"
28 | #include "absl/strings/cord.h"
29 | #include "absl/strings/string_view.h"
30 | #include "absl/types/span.h"
31 | #include "cpp/common.h"
32 | #include "cpp/thread_pool.h"
33 | #include "riegeli/base/maker.h"
34 | #include "riegeli/bytes/chain_writer.h"
35 | #include "riegeli/bytes/cord_writer.h"
36 | #include "riegeli/bytes/string_reader.h"
37 | #include "riegeli/bytes/string_writer.h"
38 | #include "riegeli/chunk_encoding/chunk.h"
39 | #include "riegeli/chunk_encoding/compressor_options.h"
40 | #include "riegeli/chunk_encoding/constants.h"
41 | #include "riegeli/chunk_encoding/simple_encoder.h"
42 | #include "riegeli/records/record_reader.h"
43 |
44 | namespace array_record {
45 | namespace {
46 |
47 | TEST(SequencedChunkWriterTest, RvalCtorTest) {
48 | // Constructs SequencedChunkWriter by taking the ownership of the other
49 | // riegeli writer.
50 | {
51 | std::string dest;
52 | auto str_writer = riegeli::StringWriter(&dest);
53 | auto to_string = SequencedChunkWriter(std::move(str_writer));
54 | }
55 | {
56 | absl::Cord cord;
57 | auto cord_writer = riegeli::CordWriter(&cord);
58 | auto to_cord = SequencedChunkWriter(std::move(cord_writer));
59 | }
60 | {
61 | std::string dest;
62 | auto str_writer = riegeli::StringWriter(&dest);
63 | auto to_string =
64 | std::make_unique<SequencedChunkWriter<riegeli::StringWriter<std::string*>>>(
65 | std::move(str_writer));
66 | }
67 | {
68 | absl::Cord cord;
69 | auto cord_writer = riegeli::CordWriter(&cord);
70 | auto to_cord =
71 | std::make_unique<SequencedChunkWriter<riegeli::CordWriter<absl::Cord*>>>(
72 | std::move(cord_writer));
73 | }
74 | }
75 |
76 | TEST(SequencedChunkWriterTest, DestArgsCtorTest) {
77 | // Constructs SequencedChunkWriter by forwarding the constructor arguments to
78 | // templated riegeli writer.
79 | {
80 | std::string dest;
81 | auto to_string =
82 | SequencedChunkWriter<riegeli::StringWriter<std::string*>>(riegeli::Maker(&dest));
83 | }
84 | {
85 | absl::Cord cord;
86 | auto to_cord =
87 | SequencedChunkWriter<riegeli::CordWriter<absl::Cord*>>(riegeli::Maker(&cord));
88 | }
89 |
90 | {
91 | std::string dest;
92 | auto to_string =
93 | std::make_unique<SequencedChunkWriter<riegeli::StringWriter<std::string*>>>(
94 | riegeli::Maker(&dest));
95 | }
96 | {
97 | absl::Cord cord;
98 | auto to_cord =
99 | std::make_unique<SequencedChunkWriter<riegeli::CordWriter<absl::Cord*>>>(
100 | riegeli::Maker(&cord));
101 | }
102 | }
103 |
104 | class TestCommitChunkCallback
105 | : public SequencedChunkWriterBase::SubmitChunkCallback {
106 | public:
107 | void operator()(uint64_t chunk_seq, uint64_t chunk_offset,
108 | uint64_t decoded_data_size, uint64_t num_records) override {
109 | chunk_offsets_.push_back(chunk_offset);
110 | }
111 | absl::Span<const uint64_t> get_chunk_offsets() const {
112 | return chunk_offsets_;
113 | }
114 |
115 | private:
116 | std::vector<uint64_t> chunk_offsets_;
117 | };
118 |
119 | TEST(SequencedChunkWriterTest, SanityTestCodeSnippet) {
120 | std::string encoded;
121 | auto callback = TestCommitChunkCallback();
122 |
123 | auto writer = std::make_shared<SequencedChunkWriter<riegeli::StringWriter<std::string*>>>(
124 | riegeli::Maker(&encoded));
125 | writer->set_submit_chunk_callback(&callback);
126 | ASSERT_TRUE(writer->ok()) << writer->status();
127 |
128 | for (auto i : Seq(3)) {
129 | std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task([i] {
130 | riegeli::Chunk chunk;
131 | riegeli::SimpleEncoder encoder(
132 | riegeli::CompressorOptions().set_uncompressed(),
133 | riegeli::SimpleEncoder::TuningOptions().set_size_hint(1));
134 | std::string text_to_encode = std::to_string(i);
135 | EXPECT_TRUE(encoder.AddRecord(absl::string_view(text_to_encode)));
136 | riegeli::ChunkType chunk_type;
137 | uint64_t decoded_data_size;
138 | uint64_t num_records;
139 | riegeli::ChainWriter chain_writer(&chunk.data);
140 | EXPECT_TRUE(encoder.EncodeAndClose(chain_writer, chunk_type, num_records,
141 | decoded_data_size));
142 | EXPECT_TRUE(chain_writer.Close());
143 | chunk.header = riegeli::ChunkHeader(chunk.data, chunk_type, num_records,
144 | decoded_data_size);
145 | return chunk;
146 | });
147 | ASSERT_TRUE(writer->CommitFutureChunk(encoding_task.get_future()));
148 | encoding_task();
149 | writer->SubmitFutureChunks(false);
150 | }
151 | // Calling SubmitFutureChunks(true) blocks the current thread until all
152 | // encoding tasks complete.
153 | EXPECT_TRUE(writer->SubmitFutureChunks(true));
154 | // Paddings should not cause any failure.
155 | EXPECT_TRUE(writer->Close());
156 |
157 | // File produced by SequencedChunkWriter should be a valid riegeli file.
158 | auto reader =
159 | riegeli::RecordReader(riegeli::Maker<riegeli::StringReader<>>(encoded));
160 | ASSERT_TRUE(reader.CheckFileFormat());
161 | // Read sequentially
162 | absl::Cord result;
163 | EXPECT_TRUE(reader.ReadRecord(result));
164 | EXPECT_EQ(result, "0");
165 | EXPECT_TRUE(reader.ReadRecord(result));
166 | EXPECT_EQ(result, "1");
167 | EXPECT_TRUE(reader.ReadRecord(result));
168 | EXPECT_EQ(result, "2");
169 | EXPECT_FALSE(reader.ReadRecord(result));
170 |
171 | // We can use the chunk_offsets information to randomly access records.
172 | auto offsets = callback.get_chunk_offsets();
173 | EXPECT_TRUE(reader.Seek(offsets[1]));
174 | EXPECT_TRUE(reader.ReadRecord(result));
175 | EXPECT_EQ(result, "1");
176 | EXPECT_TRUE(reader.Seek(offsets[0]));
177 | EXPECT_TRUE(reader.ReadRecord(result));
178 | EXPECT_EQ(result, "0");
179 | EXPECT_TRUE(reader.Seek(offsets[2]));
180 | EXPECT_TRUE(reader.ReadRecord(result));
181 | EXPECT_EQ(result, "2");
182 |
183 | EXPECT_TRUE(reader.Close());
184 | }
185 |
186 | TEST(SequencedChunkWriterTest, SanityTestBadChunk) {
187 | std::string encoded;
188 | auto callback = TestCommitChunkCallback();
189 |
190 | auto writer = std::make_shared<SequencedChunkWriter<riegeli::StringWriter<std::string*>>>(
191 | riegeli::Maker(&encoded));
192 | writer->set_submit_chunk_callback(&callback);
193 | ASSERT_TRUE(writer->ok()) << writer->status();
194 | std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
195 | [] { return absl::InternalError("On purpose"); });
196 | EXPECT_TRUE(writer->CommitFutureChunk(encoding_task.get_future()));
197 | EXPECT_TRUE(writer->SubmitFutureChunks(false));
198 | encoding_task();
199 | // We should see the error being populated even when we try to run it with the
200 | // non-blocking version.
201 | EXPECT_FALSE(writer->SubmitFutureChunks(false));
202 | EXPECT_EQ(writer->status().code(), absl::StatusCode::kInternal);
203 |
204 | EXPECT_FALSE(writer->Close());
205 | EXPECT_TRUE(callback.get_chunk_offsets().empty());
206 | }
207 |
208 | } // namespace
209 | } // namespace array_record
210 |
--------------------------------------------------------------------------------
/cpp/test_utils.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/test_utils.h"
17 |
18 | #include <cstring>
19 | #include <string>
20 |
21 | #include "cpp/common.h"
22 |
23 | namespace array_record {
24 |
25 | std::string MTRandomBytes(std::mt19937& bitgen, size_t length) {
26 | std::string result(length, '\0');
27 |
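28 | // Fill the leading (length % 4) bytes from a single 32-bit draw, then fill
29 | // the remainder with whole 32-bit draws; a given generator state and length
30 | // therefore always produce the same bytes.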
28 | size_t gen_bytes = sizeof(uint32_t);
29 | size_t rem = length % gen_bytes;
30 | std::mt19937::result_type val = bitgen();
31 | char* ptr = result.data();
32 | std::memcpy(ptr, &val, rem);
33 | ptr += rem;
34 |
35 | for (auto _ : Seq(length / gen_bytes)) {
36 | uint32_t val = bitgen();
37 | std::memcpy(ptr, &val, gen_bytes);
38 | ptr += gen_bytes;
39 | }
40 | return result;
41 | }
42 |
43 | } // namespace array_record
44 |
--------------------------------------------------------------------------------
/cpp/test_utils.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_TEST_UTILS_H_
17 | #define ARRAY_RECORD_CPP_TEST_UTILS_H_
18 |
19 | #include <random>
20 | #include <string>
21 | namespace array_record {
22 |
23 | // Generates a sequence of random bytes deterministically.
24 | std::string MTRandomBytes(std::mt19937& bitgen, size_t length);
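25 | //
26 | // Usage sketch: two identically seeded generators produce identical bytes.
27 | //
28 | //   std::mt19937 gen1, gen2;
29 | //   std::string a = MTRandomBytes(gen1, 16);
30 | //   std::string b = MTRandomBytes(gen2, 16);
31 | //   // a == b holds here.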
25 |
26 | } // namespace array_record
27 |
28 | #endif // ARRAY_RECORD_CPP_TEST_UTILS_H_
29 |
--------------------------------------------------------------------------------
/cpp/test_utils_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/test_utils.h"
17 |
18 | #include <random>
19 |
20 | #include "gmock/gmock.h"
21 | #include "gtest/gtest.h"
22 | #include "absl/strings/string_view.h"
23 | #include "cpp/common.h"
24 |
25 | namespace array_record {
26 | namespace {
27 |
28 | TEST(MTRandomBytesTest, ZeroLen) {
29 | std::mt19937 bitgen;
30 | auto result = MTRandomBytes(bitgen, 0);
31 | ASSERT_EQ(result.size(), 0);
32 | }
33 |
34 | TEST(MTRandomBytesTest, OneByte) {
35 | std::mt19937 bitgen, bitgen2;
36 | auto result = MTRandomBytes(bitgen, 1);
37 | ASSERT_EQ(result.size(), 1);
38 | ASSERT_NE(result[0], '\0');
39 |
40 | auto val = bitgen2();
41 | char char_val = *reinterpret_cast(&val);
42 | ASSERT_EQ(result[0], char_val);
43 | }
44 |
45 | TEST(MTRandomBytesTest, LargeVals) {
46 | constexpr size_t len = 123;
47 | std::mt19937 bitgen, bitgen2;
48 |
49 | auto result1 = MTRandomBytes(bitgen, len);
50 | auto result2 = MTRandomBytes(bitgen2, len);
51 | ASSERT_EQ(result1.size(), len);
52 | ASSERT_EQ(result2.size(), len);
53 | ASSERT_EQ(result1, result2);
54 | }
55 |
56 | } // namespace
57 |
58 | } // namespace array_record
59 |
--------------------------------------------------------------------------------
/cpp/thread_pool.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/thread_pool.h"
17 |
18 | #include "absl/flags/flag.h"
19 |
20 | ABSL_FLAG(uint32_t, array_record_global_pool_size, 64,
21 | "Number of threads for ArrayRecordGlobalPool");
22 |
23 | namespace array_record {
24 |
25 | ARThreadPool* ArrayRecordGlobalPool() {
26 | static ARThreadPool* pool_ = []() -> ARThreadPool* {
27 | ARThreadPool* pool = new
28 | Eigen::ThreadPool(absl::GetFlag(FLAGS_array_record_global_pool_size));
29 | return pool;
30 | }();
31 | return pool_;
32 | }
33 |
34 | } // namespace array_record
35 |
--------------------------------------------------------------------------------
/cpp/thread_pool.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_THREAD_POOL_H_
17 | #define ARRAY_RECORD_CPP_THREAD_POOL_H_
18 |
19 | #define EIGEN_USE_CUSTOM_THREAD_POOL
20 | #include "unsupported/Eigen/CXX11/ThreadPool"
21 |
22 | namespace array_record {
23 |
24 | using ARThreadPool = Eigen::ThreadPoolInterface;
25 |
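26 | // Returns a process-wide pool whose size is set by the
27 | // --array_record_global_pool_size flag (see thread_pool.cc). Usage sketch:
28 | //
29 | //   ARThreadPool* pool = ArrayRecordGlobalPool();
30 | //   pool->Schedule([] { /* work */ });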
26 | ARThreadPool* ArrayRecordGlobalPool();
27 |
28 | } // namespace array_record
29 |
30 | #endif // ARRAY_RECORD_CPP_THREAD_POOL_H_
31 |
--------------------------------------------------------------------------------
/cpp/tri_state_ptr.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2024 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_TRI_STATE_PTR_H_
17 | #define ARRAY_RECORD_CPP_TRI_STATE_PTR_H_
18 |
19 | #include <stdint.h>
20 |
21 | #include <atomic>
22 | #include <memory>
23 | #include <type_traits>
24 | #include <utility>
25 |
26 | #include "absl/base/thread_annotations.h"
27 | #include "absl/synchronization/mutex.h"
28 | #include "cpp/common.h"
29 |
30 | namespace array_record {
31 |
32 | /** TriStatePtr is a wrapper around a pointer that allows for unique and
33 |  * shared references.
34 | *
35 | * There are three states:
36 | *
37 | * - NoRef: The object does not have shared or unique references.
38 | * - Sharing: The object is shared.
39 | * - Unique: The object is referenced by a unique pointer wrapper.
40 | *
41 |  * The state transitions from NoRef to Sharing when MakeShared is called.
42 |  * An internal reference count is incremented when a SharedRef is created.
43 | *
44 |  *           SharedRef ref = MakeShared();
45 |  *   NoRef ----------------------------> Sharing ---+
46 |  *         <----------------------------    ^       | MakeShared()
47 |  *           All SharedRef deallocated      +-------+
48 | *
49 | * The state can also transition to Unique when WaitAndMakeUnique is called.
50 | * We can only hold one unique reference at a time.
51 | *
52 |  *           UniqueRef ref = WaitAndMakeUnique();
53 |  *   NoRef ----------------------------> Unique
54 |  *         <----------------------------
55 |  *           The UniqueRef is deallocated
56 | *
57 |  * Other than the state transitions above, the state-transition methods block
58 |  * until the requested state becomes possible. On deallocation, the destructor blocks
59 | * until the state is NoRef.
60 | *
61 | * Example usage:
62 | *
63 |  *   TriStatePtr<FooBase> foo_ptr(riegeli::Maker<Foo>(...));
64 | * // Create a shared reference to work on other threads.
65 | * pool->Schedule([refobj = foo_ptr.MakeShared()] {
66 | * refobj->FooMethod();
67 | * });
68 | *
69 | * // Blocks until refobj is out of scope.
70 |  *   auto unique_ref = foo_ptr.WaitAndMakeUnique();
71 | * unique_ref->CleanupFoo();
72 | *
73 | */
74 | template <typename BaseT>
75 | class TriStatePtr {
76 | public:
77 | DECLARE_IMMOBILE_CLASS(TriStatePtr);
78 | TriStatePtr() = default;
79 |
80 | ~TriStatePtr() {
81 | absl::MutexLock l(&mu_);
82 | mu_.Await(absl::Condition(
83 | +[](State* sharing_state) { return *sharing_state == State::kNoRef; },
84 | &state_));
85 | }
86 |
87 | explicit TriStatePtr(std::unique_ptr<BaseT> ptr) : ptr_(std::move(ptr)) {}
88 |
89 | class SharedRef {
90 | public:
91 | SharedRef(TriStatePtr* parent) : parent_(parent) {}
92 |
93 | SharedRef(const SharedRef& other) : parent_(other.parent_) {
94 | parent_->ref_count_++;
95 | }
96 | SharedRef& operator=(const SharedRef& other) {
97 |   other.parent_->ref_count_++;
98 |   // Release the old reference (mirrors ~SharedRef), then rebind.
99 |   if (parent_->ref_count_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
100 |     absl::MutexLock l(&parent_->mu_);
101 |     parent_->state_ = State::kNoRef;
102 |   }
103 |   parent_ = other.parent_;
104 |   return *this;
105 | }
101 |
102 | ~SharedRef() {
103 | int32_t ref_count =
104 | parent_->ref_count_.fetch_sub(1, std::memory_order_acq_rel) - 1;
105 | if (ref_count == 0) {
106 | absl::MutexLock l(&parent_->mu_);
107 | parent_->state_ = State::kNoRef;
108 | }
109 | }
110 |
111 | const BaseT& operator*() const { return *parent_->ptr_.get(); }
112 | const BaseT* operator->() const { return parent_->ptr_.get(); }
113 | BaseT& operator*() { return *parent_->ptr_.get(); }
114 | BaseT* operator->() { return parent_->ptr_.get(); }
115 |
116 | private:
117 | TriStatePtr* parent_ = nullptr;
118 | };
119 |
120 | class UniqueRef {
121 | public:
122 | DECLARE_MOVE_ONLY_CLASS(UniqueRef);
123 | UniqueRef(TriStatePtr* parent) : parent_(parent) {}
124 |
125 | ~UniqueRef() {
126 | absl::MutexLock l(&parent_->mu_);
127 | parent_->state_ = State::kNoRef;
128 | }
129 |
130 | const BaseT& operator*() const { return *parent_->ptr_.get(); }
131 | const BaseT* operator->() const { return parent_->ptr_.get(); }
132 | BaseT& operator*() { return *parent_->ptr_.get(); }
133 | BaseT* operator->() { return parent_->ptr_.get(); }
134 |
135 | private:
136 | TriStatePtr* parent_;
137 | };
138 |
139 | SharedRef MakeShared() {
140 | absl::MutexLock l(&mu_);
141 | mu_.Await(absl::Condition(
142 | +[](State* sharing_state) { return *sharing_state != State::kUnique; },
143 | &state_));
144 | state_ = State::kSharing;
145 | ref_count_++;
146 | return SharedRef(this);
147 | }
148 |
149 | UniqueRef WaitAndMakeUnique() {
150 | absl::MutexLock l(&mu_);
151 | mu_.Await(absl::Condition(
152 | +[](State* sharing_state) { return *sharing_state == State::kNoRef; },
153 | &state_));
154 | state_ = State::kUnique;
155 | return UniqueRef(this);
156 | }
157 |
158 | enum class State {
159 | kNoRef = 0,
160 | kSharing = 1,
161 | kUnique = 2,
162 | };
163 |
164 | State state() const {
165 | absl::MutexLock l(&mu_);
166 | return state_;
167 | }
168 |
169 | private:
170 | mutable absl::Mutex mu_;
171 | std::atomic_int32_t ref_count_ = 0;
172 | State state_ ABSL_GUARDED_BY(mu_) = State::kNoRef;
173 | std::unique_ptr ptr_;
174 | };
175 |
176 | } // namespace array_record
177 |
178 | #endif // ARRAY_RECORD_CPP_TRI_STATE_PTR_H_
179 |
--------------------------------------------------------------------------------
/cpp/tri_state_ptr_test.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/tri_state_ptr.h"
17 | #include <utility>
18 |
19 | #include "gtest/gtest.h"
20 | #include "absl/synchronization/notification.h"
21 | #include "cpp/common.h"
22 | #include "cpp/thread_pool.h"
23 | #include "riegeli/base/maker.h"
24 |
25 | namespace array_record {
26 | namespace {
27 |
28 | class FooBase {
29 | public:
30 | virtual ~FooBase() = default;
31 | virtual int value() const = 0;
32 | virtual void add_value(int v) = 0;
33 | virtual void mul_value(int v) = 0;
34 | };
35 |
36 | class Foo : public FooBase {
37 | public:
38 | explicit Foo(int v) : value_(v) {}
39 | DECLARE_MOVE_ONLY_CLASS(Foo);
40 |
41 | int value() const override { return value_; };
42 | void add_value(int v) override { value_ += v; }
43 | void mul_value(int v) override { value_ *= v; }
44 |
45 | private:
46 | int value_;
47 | };
48 |
49 | class TriStatePtrTest : public testing::Test {
50 | public:
51 | TriStatePtrTest() : pool_(ArrayRecordGlobalPool()) {}
52 |
53 | protected:
54 | ARThreadPool* pool_;
55 | };
56 |
57 | TEST_F(TriStatePtrTest, SanityTest) {
58 | TriStatePtr<FooBase> foo_main(std::move(riegeli::Maker<Foo>(1)));
59 | EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kNoRef);
60 | absl::Notification notification;
61 | {
62 | pool_->Schedule(
63 | [foo_shared = foo_main.MakeShared(), ¬ification]() mutable {
64 | notification.WaitForNotification();
65 | EXPECT_EQ(foo_shared->value(), 1);
66 | const auto second_foo_shared = foo_shared;
67 | foo_shared->add_value(1);
68 | EXPECT_EQ(second_foo_shared->value(), 2);
69 | });
70 | }
71 | EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kSharing);
72 | notification.Notify();
73 | auto foo_unique = foo_main.WaitAndMakeUnique();
74 | foo_unique->mul_value(3);
75 | EXPECT_EQ(foo_unique->value(), 6);
76 | EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kUnique);
77 | }
78 |
79 | } // namespace
80 | } // namespace array_record
81 |
--------------------------------------------------------------------------------
/oss/README.md:
--------------------------------------------------------------------------------
1 | # Steps to build a new array_record pip package
2 |
3 | 1. Update the version number in setup.py
4 |
5 | 2. In the root folder, run
6 |
7 | ```
8 | ./oss/build_whl.sh
9 | ```
10 | to build with the `python3` found on your PATH. To target a specific Python version, set
11 | ```
12 | PYTHON_VERSION=3.9 ./oss/build_whl.sh
13 | ```
14 |
15 | 3. Wheels are in `all_dist/`.
16 |
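17 | 4. Optionally, install a wheel and smoke-test the import (adjust the
18 |    filename to the version and Python tag you built):
19 |
20 | ```
21 | python3 -m pip install all_dist/array_record-*.whl
22 | python3 -c 'import array_record'
23 | ```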
--------------------------------------------------------------------------------
/oss/build.Dockerfile:
--------------------------------------------------------------------------------
1 | # Constructs the environment within which we will build the pip wheels.
2 |
3 |
4 | ARG AUDITWHEEL_PLATFORM
5 |
6 | FROM quay.io/pypa/${AUDITWHEEL_PLATFORM}
7 |
8 | ARG PYTHON_VERSION
9 | ARG PYTHON_BIN
10 | ARG BAZEL_VERSION
11 |
12 | ENV DEBIAN_FRONTEND=noninteractive
13 |
14 | RUN ulimit -n 1024 && yum install -y rsync
15 | ENV PATH="${PYTHON_BIN}:${PATH}"
16 |
17 | # Download the correct bazel version and make sure it's on path.
18 | RUN BAZEL_ARCH_SUFFIX="$(uname -m | sed s/aarch64/arm64/)" \
19 | && curl -sSL --fail -o /usr/local/bin/bazel "https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-linux-$BAZEL_ARCH_SUFFIX" \
20 | && chmod a+x /usr/local/bin/bazel
21 |
22 | # Install dependencies needed for array_record.
23 | RUN --mount=type=cache,target=/root/.cache \
24 | ${PYTHON_BIN}/python -m pip install -U \
25 | absl-py \
26 | auditwheel \
27 | etils[epath] \
28 | patchelf \
29 | setuptools \
30 | twine \
31 | wheel;
32 |
33 | WORKDIR "/tmp/array_record"
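34 |
35 | # For reference, oss/runner_common.sh builds this image roughly as follows
36 | # (values illustrative):
37 | #
38 | #   DOCKER_BUILDKIT=1 docker build --progress=plain \
39 | #     --build-arg AUDITWHEEL_PLATFORM=manylinux2014_x86_64 \
40 | #     --build-arg PYTHON_VERSION=3.11 \
41 | #     --build-arg PYTHON_BIN=/opt/python/cp311-cp311/bin \
42 | #     --build-arg BAZEL_VERSION=8.0.0 \
43 | #     -t array_record:3.11 - < oss/build.Dockerfile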
--------------------------------------------------------------------------------
/oss/build_whl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Build wheel for the python version specified by $PYTHON_VERSION.
3 | # Optionally, can set the environment variable $PYTHON_BIN to refer to a
4 | # specific python interpreter.
5 |
6 | set -e -x
7 |
8 | if [ -z "${PYTHON_BIN}" ]; then
9 | if [ -z "${PYTHON_VERSION}" ]; then
10 | PYTHON_BIN=$(which python3)
11 | else
12 | PYTHON_BIN=$(which python${PYTHON_VERSION})
13 | fi
14 | fi
15 |
16 | PYTHON_MAJOR_VERSION=$(${PYTHON_BIN} -c 'import sys; print(sys.version_info.major)')
17 | PYTHON_MINOR_VERSION=$(${PYTHON_BIN} -c 'import sys; print(sys.version_info.minor)')
18 | PYTHON_VERSION="${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}"
19 | export PYTHON_VERSION="${PYTHON_VERSION}"
20 |
21 | function write_to_bazelrc() {
22 | echo "$1" >> .bazelrc
23 | }
24 |
25 | function main() {
26 | # Remove .bazelrc if it already exists
27 | [ -e .bazelrc ] && rm .bazelrc
28 |
29 | write_to_bazelrc "build -c opt"
30 | write_to_bazelrc "build --cxxopt=-std=c++17"
31 | write_to_bazelrc "build --host_cxxopt=-std=c++17"
32 | write_to_bazelrc "build --experimental_repo_remote_exec"
33 | write_to_bazelrc "build --python_path=\"${PYTHON_BIN}\""
34 |
35 | if [ -n "${CROSSTOOL_TOP}" ]; then
36 | write_to_bazelrc "build --crosstool_top=${CROSSTOOL_TOP}"
37 | write_to_bazelrc "test --crosstool_top=${CROSSTOOL_TOP}"
38 | fi
39 |
40 | export USE_BAZEL_VERSION="${BAZEL_VERSION}"
41 | bazel clean
42 | bazel build ...
43 | bazel test --verbose_failures --test_output=errors ...
44 |
45 | DEST="/tmp/array_record/all_dist"
46 | # Create the destination directory for the built wheels. DEST is an
47 | # absolute path, so the copy at the end of this function works regardless
48 | # of the current working directory.
49 | mkdir -p "${DEST}"
50 | echo "=== destination directory: ${DEST}"
51 |
52 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
53 |
54 | echo $(date) : "=== Using tmpdir: ${TMPDIR}"
55 | mkdir "${TMPDIR}/array_record"
56 |
57 | echo $(date) : "=== Copy array_record files"
58 |
59 | cp setup.py "${TMPDIR}"
60 | cp LICENSE "${TMPDIR}"
61 | rsync -avm -L --exclude="bazel-*/" . "${TMPDIR}/array_record"
62 | rsync -avm -L --include="*.so" --include="*_pb2.py" \
63 | --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \
64 | bazel-bin/cpp "${TMPDIR}/array_record"
65 | rsync -avm -L --include="*.so" --include="*_pb2.py" \
66 | --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \
67 | bazel-bin/python "${TMPDIR}/array_record"
68 |
69 | pushd ${TMPDIR}
70 | echo $(date) : "=== Building wheel"
71 | ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION}
72 |
73 | if [ -n "${AUDITWHEEL_PLATFORM}" ]; then
74 | echo $(date) : "=== Auditing wheel"
75 | auditwheel repair --plat ${AUDITWHEEL_PLATFORM} -w dist dist/*.whl
76 | fi
77 |
78 | echo $(date) : "=== Listing wheel"
79 | ls -lrt dist/*.whl
80 | cp dist/*.whl "${DEST}"
81 | popd
82 |
83 | echo $(date) : "=== Output wheel file is in: ${DEST}"
84 | }
85 |
86 | main
87 |
--------------------------------------------------------------------------------
/oss/runner_common.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Builds ArrayRecord from source code located in SOURCE_DIR producing wheels
4 | # under $SOURCE_DIR/all_dist.
5 | function build_and_test_array_record_linux() {
6 | SOURCE_DIR=$1
7 |
8 | # Automatically decide which platform to build for by checking on which
9 | # platform this runs.
10 | AUDITWHEEL_PLATFORM="manylinux2014_$(uname -m)"
11 |
12 | # Using a previous version of Bazel to avoid:
13 | # https://github.com/bazelbuild/bazel/issues/8622
14 | export BAZEL_VERSION="8.0.0"
15 |
16 | # Build wheels for multiple Python minor versions.
17 | PYTHON_MAJOR_VERSION=3
18 | for PYTHON_MINOR_VERSION in 10 11 12
19 | do
20 | PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}
21 | PYTHON_BIN=/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin
22 |
23 | # Cleanup older images.
24 | docker rmi -f array_record:${PYTHON_VERSION}
25 | docker rm -f array_record
26 |
27 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
28 | --build-arg AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
29 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
30 | --build-arg PYTHON_BIN=${PYTHON_BIN} \
31 | --build-arg BAZEL_VERSION=${BAZEL_VERSION} \
32 | -t array_record:${PYTHON_VERSION} - < ${SOURCE_DIR}/oss/build.Dockerfile
33 |
34 | docker run --rm -a stdin -a stdout -a stderr \
35 | --env PYTHON_BIN="${PYTHON_BIN}/python" \
36 | --env BAZEL_VERSION=${BAZEL_VERSION} \
37 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
38 | -v $SOURCE_DIR:/tmp/array_record \
39 | --name array_record array_record:${PYTHON_VERSION} \
40 | bash oss/build_whl.sh
41 | done
42 |
43 | ls ${SOURCE_DIR}/all_dist/*.whl
44 | }
45 |
46 | function install_and_init_pyenv {
47 | pyenv_root=${1:-$HOME/.pyenv}
48 | export PYENV_ROOT=$pyenv_root
49 | if [[ ! -d $PYENV_ROOT ]]; then
50 | echo "Installing pyenv..."
51 | git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT"
52 | export PATH="/home/kbuilder/.local/bin:$PYENV_ROOT/bin:$PATH"
53 | eval "$(pyenv init --path)"
54 | fi
55 |
56 | echo "Python setup..."
57 | pyenv install -s "$PYENV_PYTHON_VERSION"
58 | pyenv global "$PYENV_PYTHON_VERSION"
59 | PYTHON=$(pyenv which python)
60 | }
61 |
62 | function setup_env_vars_py310 {
63 | # This controls the python binary to use.
64 | PYTHON=python3.10
65 | PYTHON_STR=python3.10
66 | PYTHON_MAJOR_VERSION=3
67 | PYTHON_MINOR_VERSION=10
68 | # This is for pyenv install.
69 | PYENV_PYTHON_VERSION=3.10.13
70 | }
71 |
72 | function update_bazel_macos {
73 | BAZEL_VERSION=$1
74 | ARCH="$(uname -m)"
75 | curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh -O
76 | ls
77 | chmod +x bazel-*.sh
78 | ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh --user
79 | rm -f ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh
80 | # Add new bazel installation to path
81 | PATH="/Users/kbuilder/bin:$PATH"
82 | }
83 |
84 | function build_and_test_array_record_macos() {
85 | SOURCE_DIR=$1
86 | # Set up Bazel.
87 | # Using a previous version of Bazel to avoid:
88 | # https://github.com/bazelbuild/bazel/issues/8622
89 | export BAZEL_VERSION="8.0.0"
90 | update_bazel_macos ${BAZEL_VERSION}
91 | bazel --version
92 |
93 | # Set up Pyenv.
94 | setup_env_vars_py310
95 | install_and_init_pyenv
96 |
97 | # Build and test ArrayRecord.
98 | cd ${SOURCE_DIR}
99 | bash ${SOURCE_DIR}/oss/build_whl.sh
100 |
101 | ls ${SOURCE_DIR}/all_dist/*.whl
102 | }
103 |
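104 | # Typical CI usage (illustrative): source this file and build wheels for the
105 | # current checkout; they land under $SOURCE_DIR/all_dist.
106 | #
107 | #   source oss/runner_common.sh
108 | #   build_and_test_array_record_linux "$(pwd)"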
--------------------------------------------------------------------------------
/python/BUILD:
--------------------------------------------------------------------------------
1 | # Python binding for ArrayRecord
2 |
3 | load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
4 | load("@pypi//:requirements.bzl", "requirement")
5 |
6 | package(default_visibility = ["//visibility:public"])
7 |
8 | licenses(["notice"])
9 |
10 | pybind_extension(
11 | name = "array_record_module",
12 | srcs = ["array_record_module.cc"],
13 | deps = [
14 | "@abseil-cpp//absl/status",
15 | "@abseil-cpp//absl/strings",
16 | "@abseil-cpp//absl/strings:str_format",
17 | "//cpp:array_record_reader",
18 | "//cpp:array_record_writer",
19 | "//cpp:thread_pool",
20 | "@riegeli//riegeli/base:initializer",
21 | "@riegeli//riegeli/bytes:fd_reader",
22 | "@riegeli//riegeli/bytes:fd_writer",
23 | ],
24 | )
25 |
26 | py_test(
27 | name = "array_record_module_test",
28 | srcs = ["array_record_module_test.py"],
29 | data = [":array_record_module.so"],
30 | deps = [
31 | "@abseil-py//absl/testing:absltest",
32 | ],
33 | )
34 |
35 | py_library(
36 | name = "array_record_data_source",
37 | srcs = ["array_record_data_source.py"],
38 | data = [":array_record_module.so"],
39 | deps = [
40 | requirement("etils"),
41 | ],
42 | )
43 |
44 | py_test(
45 | name = "array_record_data_source_test",
46 | srcs = ["array_record_data_source_test.py"],
47 | args = ["--test_srcdir=python/testdata"],
48 | data = [
49 | ":array_record_module.so",
50 | "//python/testdata:digits.array_record-00000-of-00002",
51 | "//python/testdata:digits.array_record-00001-of-00002",
52 | ],
53 | deps = [
54 | ":array_record_data_source",
55 | "@abseil-py//absl/testing:absltest",
56 | "@abseil-py//absl/testing:flagsaver",
57 | "@abseil-py//absl/testing:parameterized",
58 | ],
59 | )
60 |
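61 | # The targets above can also be built and tested directly, outside of
62 | # oss/build_whl.sh, e.g.:
63 | #
64 | #   bazel test --test_output=errors //python:array_record_module_test \
65 | #     //python:array_record_data_source_test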
--------------------------------------------------------------------------------
/python/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/__init__.py
--------------------------------------------------------------------------------
/python/array_record_data_source_test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 |
14 | """Tests for ArrayRecord data sources."""
15 |
16 | from concurrent import futures
17 | import dataclasses
18 | import os
19 | import pathlib
20 | from unittest import mock
21 |
22 | from absl import flags
23 | from absl.testing import absltest
24 | from absl.testing import flagsaver
25 | from absl.testing import parameterized
26 |
27 | from python import array_record_data_source
28 | from python import array_record_module
29 |
30 |
31 | FLAGS = flags.FLAGS
32 |
33 |
34 | @dataclasses.dataclass
35 | class DummyFileInstruction:
36 | filename: str
37 | skip: int
38 | take: int
39 | examples_in_shard: int
40 |
41 |
42 | class ArrayRecordDataSourcesTest(absltest.TestCase):
43 |
44 | def setUp(self):
45 | super().setUp()
46 | self.testdata_dir = pathlib.Path(FLAGS.test_srcdir)
47 |
48 | def test_check_default_group_size(self):
49 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record")
50 | writer = array_record_module.ArrayRecordWriter(filename)
51 | writer.write(b"foobar")
52 | writer.close()
53 | reader = array_record_module.ArrayRecordReader(filename)
54 | with self.assertLogs(level="ERROR") as log_output:
55 | array_record_data_source._check_group_size(filename, reader)
56 | self.assertRegex(
57 | log_output.output[0],
58 | (
59 | r"File .* was created with group size 65536. Grain requires group"
60 | r" size 1 for good performance"
61 | ),
62 | )
63 |
64 | def test_check_valid_group_size(self):
65 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record")
66 | writer = array_record_module.ArrayRecordWriter(filename, "group_size:1")
67 | writer.write(b"foobar")
68 | writer.close()
69 | reader = array_record_module.ArrayRecordReader(filename)
70 | # group_size:1 is the value Grain expects, so no error should be logged.
71 | with self.assertNoLogs(level="ERROR"):
72 | array_record_data_source._check_group_size(filename, reader)
73 |
71 | def test_check_invalid_group_size(self):
72 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record")
73 | writer = array_record_module.ArrayRecordWriter(filename, "group_size:11")
74 | writer.write(b"foobar")
75 | writer.close()
76 | reader = array_record_module.ArrayRecordReader(filename)
77 | with self.assertLogs(level="ERROR") as log_output:
78 | array_record_data_source._check_group_size(filename, reader)
79 | self.assertRegex(
80 | log_output.output[0],
81 | (
82 | r"File .* was created with group size 11. Grain requires group size"
83 | r" 1 for good performance"
84 | ),
85 | )
86 |
87 | def test_array_record_data_source_len(self):
88 | ar = array_record_data_source.ArrayRecordDataSource([
89 | self.testdata_dir / "digits.array_record-00000-of-00002",
90 | self.testdata_dir / "digits.array_record-00001-of-00002",
91 | ])
92 | self.assertLen(ar, 10)
93 |
94 | def test_array_record_data_source_iter(self):
95 | ar = array_record_data_source.ArrayRecordDataSource([
96 | self.testdata_dir / "digits.array_record-00000-of-00002",
97 | self.testdata_dir / "digits.array_record-00001-of-00002",
98 | ])
99 | digits = [b"0", b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8", b"9"]
100 | for actual, expected in zip(ar, digits):
101 | self.assertEqual(actual, expected)
102 |
103 | def test_array_record_data_source_single_path(self):
104 | indices_to_read = [0, 1, 2, 3, 4]
105 | expected_data = [b"0", b"1", b"2", b"3", b"4"]
106 | # Use a single path instead of a list of paths/file_instructions.
107 | with array_record_data_source.ArrayRecordDataSource(
108 | self.testdata_dir / "digits.array_record-00000-of-00002"
109 | ) as ar:
110 | actual_data = [ar[x] for x in indices_to_read]
111 | self.assertEqual(expected_data, actual_data)
112 | self.assertTrue(all(reader is None for reader in ar._readers))
113 |
114 | def test_array_record_data_source_string_read_instructions(self):
115 | indices_to_read = [0, 1, 2, 3, 4]
116 | expected_data = [b"0", b"1", b"2", b"7", b"8"]
117 | # Use paths with read-instruction ranges (e.g. "[0:3]") appended.
118 | ar = array_record_data_source.ArrayRecordDataSource([
119 | self.testdata_dir / "digits.array_record-00000-of-00002[0:3]",
120 | self.testdata_dir / "digits.array_record-00001-of-00002[2:4]",
121 | ])
122 | self.assertLen(ar, 5)
123 | actual_data = [ar[x] for x in indices_to_read]
124 | self.assertEqual(expected_data, actual_data)
125 |
126 | def test_array_record_data_source_reverse_order(self):
127 | indices_to_read = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
128 | expected_data = [b"9", b"8", b"7", b"6", b"5", b"4", b"3", b"2", b"1", b"0"]
129 | with array_record_data_source.ArrayRecordDataSource([
130 | self.testdata_dir / "digits.array_record-00000-of-00002",
131 | self.testdata_dir / "digits.array_record-00001-of-00002",
132 | ]) as ar:
133 | actual_data = [ar[x] for x in indices_to_read]
134 | self.assertEqual(expected_data, actual_data)
135 | self.assertTrue(all(reader is None for reader in ar._readers))
136 |
137 | def test_array_record_data_source_random_order(self):
138 | # some random permutation
139 | indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6]
140 | expected_data = [b"3", b"0", b"5", b"9", b"2", b"1", b"4", b"7", b"8", b"6"]
141 | with array_record_data_source.ArrayRecordDataSource([
142 | self.testdata_dir / "digits.array_record-00000-of-00002",
143 | self.testdata_dir / "digits.array_record-00001-of-00002",
144 | ]) as ar:
145 | actual_data = [ar[x] for x in indices_to_read]
146 | self.assertEqual(expected_data, actual_data)
147 | self.assertTrue(all(reader is None for reader in ar._readers))
148 |
149 | def test_array_record_data_source_random_order_batched(self):
150 | # some random permutation
151 | indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6]
152 | expected_data = [b"3", b"0", b"5", b"9", b"2", b"1", b"4", b"7", b"8", b"6"]
153 | with array_record_data_source.ArrayRecordDataSource([
154 | self.testdata_dir / "digits.array_record-00000-of-00002",
155 | self.testdata_dir / "digits.array_record-00001-of-00002",
156 | ]) as ar:
157 | actual_data = ar.__getitems__(indices_to_read)
158 | self.assertEqual(expected_data, actual_data)
159 | self.assertTrue(all(reader is None for reader in ar._readers))
160 |
161 | def test_array_record_data_source_file_instructions(self):
162 | file_instruction_one = DummyFileInstruction(
163 | filename=os.fspath(
164 | self.testdata_dir / "digits.array_record-00000-of-00002"
165 | ),
166 | skip=2,
167 | take=1,
168 | examples_in_shard=3,
169 | )
170 |
171 | file_instruction_two = DummyFileInstruction(
172 | filename=os.fspath(
173 | self.testdata_dir / "digits.array_record-00001-of-00002"
174 | ),
175 | skip=2,
176 | take=2,
177 | examples_in_shard=99,
178 | )
179 |
180 | indices_to_read = [0, 1, 2]
181 | expected_data = [b"2", b"7", b"8"]
182 |
183 | with array_record_data_source.ArrayRecordDataSource(
184 | [file_instruction_one, file_instruction_two]
185 | ) as ar:
186 | self.assertLen(ar, 3)
187 | actual_data = [ar[x] for x in indices_to_read]
188 |
189 | self.assertEqual(expected_data, actual_data)
190 | self.assertTrue(all(reader is None for reader in ar._readers))
191 |
192 | def test_array_record_source_reader_idx_and_position(self):
193 | file_instructions = [
194 | # 2 records
195 | DummyFileInstruction(
196 | filename="file_1", skip=0, take=2, examples_in_shard=2
197 | ),
198 | # 3 records
199 | DummyFileInstruction(
200 | filename="file_2", skip=2, take=3, examples_in_shard=99
201 | ),
202 | # 1 record
203 | DummyFileInstruction(
204 | filename="file_3", skip=10, take=1, examples_in_shard=99
205 | ),
206 | ]
207 |
208 | expected_indices_and_positions = [
209 | (0, 0),
210 | (0, 1),
211 | (1, 2),
212 | (1, 3),
213 | (1, 4),
214 | (2, 10),
215 | ]
216 |
217 | with array_record_data_source.ArrayRecordDataSource(
218 | file_instructions
219 | ) as ar:
220 | self.assertLen(ar, 6)
221 | for record_key in range(len(ar)):
222 | self.assertEqual(
223 | expected_indices_and_positions[record_key],
224 | ar._reader_idx_and_position(record_key),
225 | )
226 |
227 | def test_array_record_source_reader_idx_and_position_negative_idx(self):
228 | with array_record_data_source.ArrayRecordDataSource([
229 | self.testdata_dir / "digits.array_record-00000-of-00002",
230 | self.testdata_dir / "digits.array_record-00001-of-00002",
231 | ]) as ar:
232 | with self.assertRaises(ValueError):
233 | ar._reader_idx_and_position(-1)
234 |
235 | with self.assertRaises(ValueError):
236 | ar._reader_idx_and_position(len(ar))
237 |
238 | def test_array_record_source_empty_sequence(self):
239 | with self.assertRaises(ValueError):
240 | with array_record_data_source.ArrayRecordDataSource([]):
241 | pass
242 |
243 | def test_repr(self):
244 | ar = array_record_data_source.ArrayRecordDataSource([
245 | self.testdata_dir / "digits.array_record-00000-of-00002",
246 | self.testdata_dir / "digits.array_record-00001-of-00002",
247 | ])
248 | self.assertRegex(repr(ar), r"ArrayRecordDataSource\(hash_of_paths=[\w]+\)")
249 |
250 |
251 | class RunInParallelTest(parameterized.TestCase):
252 |
253 | def test_the_function_is_executed_with_kwargs(self):
254 | function = mock.Mock(return_value="return value")
255 | list_of_kwargs_to_function = [
256 | {"foo": 1},
257 | {"bar": 2},
258 | ]
259 | result = array_record_data_source._run_in_parallel(
260 | function=function,
261 | list_of_kwargs_to_function=list_of_kwargs_to_function,
262 | num_workers=1,
263 | )
264 | self.assertEqual(result, ["return value", "return value"])
265 | self.assertEqual(function.call_count, 2)
266 | function.assert_has_calls([mock.call(foo=1), mock.call(bar=2)])
267 |
268 | def test_exception_is_re_raised(self):
269 | function = mock.Mock()
270 | side_effect = ["return value", ValueError("Raised!")]
271 | function.side_effect = side_effect
272 | list_of_kwargs_to_function = [
273 | {"foo": 1},
274 | {"bar": 2},
275 | ]
276 | self.assertEqual(len(side_effect), len(list_of_kwargs_to_function))
277 | with self.assertRaisesRegex(ValueError, "Raised!"):
278 | array_record_data_source._run_in_parallel(
279 | function=function,
280 | list_of_kwargs_to_function=list_of_kwargs_to_function,
281 | num_workers=1,
282 | )
283 |
284 | @parameterized.parameters([(-2,), (-1,), (0,)])
285 | def test_num_workers_cannot_be_null_or_negative(self, num_workers):
286 | function = mock.Mock(return_value="return value")
287 | list_of_kwargs_to_function = [
288 | {"foo": 1},
289 | {"bar": 2},
290 | ]
291 | with self.assertRaisesRegex(
292 | ValueError, "num_workers must be >=1 for parallelism."
293 | ):
294 | array_record_data_source._run_in_parallel(
295 | function=function,
296 | list_of_kwargs_to_function=list_of_kwargs_to_function,
297 | num_workers=num_workers,
298 | )
299 |
300 | def test_num_workers_is_passed_to_thread_executor(self):
301 | function = mock.Mock(return_value="return value")
302 | list_of_kwargs_to_function = [
303 | {"foo": 1},
304 | {"bar": 2},
305 | ]
306 | num_workers = 42
307 | with mock.patch(
308 | "concurrent.futures.ThreadPoolExecutor",
309 | wraps=futures.ThreadPoolExecutor,
310 | ) as executor:
311 | array_record_data_source._run_in_parallel(
312 | function=function,
313 | list_of_kwargs_to_function=list_of_kwargs_to_function,
314 | num_workers=num_workers,
315 | )
316 | executor.assert_called_with(num_workers)
317 |
318 |
319 | if __name__ == "__main__":
320 | absltest.main()
321 |
--------------------------------------------------------------------------------
/python/array_record_module.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include <cstdint>
17 | #include <stdexcept>
18 |
19 | #include <memory>
20 | #include <string>
21 | #include <utility>
22 | #include <vector>
23 |
24 | #include "absl/status/status.h"
25 | #include "absl/strings/str_format.h"
26 | #include "absl/strings/string_view.h"
27 | #include "cpp/array_record_reader.h"
28 | #include "cpp/array_record_writer.h"
29 | #include "cpp/thread_pool.h"
30 | #include "pybind11/gil.h"
31 | #include "pybind11/pybind11.h"
32 | #include "pybind11/pytypes.h"
33 | #include "pybind11/stl.h"
34 | #include "riegeli/base/maker.h"
35 | #include "riegeli/bytes/fd_reader.h"
36 | #include "riegeli/bytes/fd_writer.h"
37 |
38 | namespace py = pybind11;
39 |
40 | PYBIND11_MODULE(array_record_module, m) {
41 | using array_record::ArrayRecordReaderBase;
42 | using array_record::ArrayRecordWriterBase;
43 |
44 | py::class_<ArrayRecordWriterBase>(m, "ArrayRecordWriter")
45 | .def(py::init([](const std::string& path, const std::string& options,
46 | const py::kwargs& kwargs) -> ArrayRecordWriterBase* {
47 | auto status_or_option =
48 | ArrayRecordWriterBase::Options::FromString(options);
49 | if (!status_or_option.ok()) {
50 | throw py::value_error(
51 | std::string(status_or_option.status().message()));
52 | }
53 | // Release the GIL because IO is time consuming.
54 | py::gil_scoped_release scoped_release;
55 | return new array_record::ArrayRecordWriter(
56 | riegeli::Maker<riegeli::FdWriter>(path),
57 | status_or_option.value());
58 | }),
59 | py::arg("path"), py::arg("options") = "")
60 | .def("ok", &ArrayRecordWriterBase::ok)
61 | .def("close",
62 | [](ArrayRecordWriterBase& writer) {
63 | if (!writer.Close()) {
64 | throw std::runtime_error(std::string(writer.status().message()));
65 | }
66 | })
67 | .def("is_open", &ArrayRecordWriterBase::is_open)
68 | // We accept only py::bytes (and not unicode strings) since we expect
69 | // most users to store binary data (e.g. serialized protocol buffers).
70 | // We cannot know if a user wants to write+read unicode, and forcing users
71 | // to encode() their unicode strings avoids accidental conversions.
72 | .def("write", [](ArrayRecordWriterBase& writer, py::bytes record) {
73 | if (!writer.WriteRecord(record)) {
74 | throw std::runtime_error(std::string(writer.status().message()));
75 | }
76 | });
77 | py::class_<ArrayRecordReaderBase>(m, "ArrayRecordReader")
78 | .def(py::init([](const std::string& path, const std::string& options,
79 | const py::kwargs& kwargs) -> ArrayRecordReaderBase* {
80 | auto status_or_option =
81 | ArrayRecordReaderBase::Options::FromString(options);
82 | if (!status_or_option.ok()) {
83 | throw py::value_error(
84 | std::string(status_or_option.status().message()));
85 | }
86 | riegeli::FdReaderBase::Options file_reader_options;
87 | if (kwargs.contains("file_reader_buffer_size")) {
88 | auto file_reader_buffer_size =
89 | kwargs["file_reader_buffer_size"].cast<int64_t>();
90 | file_reader_options.set_buffer_size(file_reader_buffer_size);
91 | }
92 | // Release the GIL because IO is time consuming.
93 | py::gil_scoped_release scoped_release;
94 | return new array_record::ArrayRecordReader(
95 | riegeli::Maker<riegeli::FdReader>(
96 | path, std::move(file_reader_options)),
97 | status_or_option.value(),
98 | array_record::ArrayRecordGlobalPool());
99 | }),
100 | py::arg("path"), py::arg("options") = "", R"(
101 | ArrayRecordReader for fast sequential or random access.
102 |
103 | Args:
104 | path: File path to the input file.
105 | options: String with options for ArrayRecord. See syntax below.
106 | Kwargs:
107 | file_reader_buffer_size: Optional size of the buffer (in bytes)
108 | for the underlying file (Riegeli) reader. The default buffer
109 | size is 1 MiB.
110 | file_options: Optional file::Options to use for the underlying
111 | file (Riegeli) reader.
112 |
113 | options ::= option? ("," option?)*
114 | option ::=
115 | "readahead_buffer_size" ":" readahead_buffer_size |
116 | "max_parallelism" ":" max_parallelism
117 | readahead_buffer_size ::= non-negative integer expressed as real with
118 | optional suffix [BkKMGTPE]. (Default 16MB.) Setting it to 0
119 | optimizes random access performance.
120 | max_parallelism ::= `auto` or non-negative integer. Each parallel
121 | thread owns its readahead buffer with the size
122 | `readahead_buffer_size`. (Default: thread pool size.) Setting it
123 | to 0 optimizes random access performance.
124 |
125 | The default option is optimized for sequential access. To optimize
126 | the random access performance, set the options to
127 | "readahead_buffer_size:0,max_parallelism:0".
128 | )")
129 | .def("ok", &ArrayRecordReaderBase::ok)
130 | .def("close",
131 | [](ArrayRecordReaderBase& reader) {
132 | if (!reader.Close()) {
133 | throw std::runtime_error(std::string(reader.status().message()));
134 | }
135 | })
136 | .def("is_open", &ArrayRecordReaderBase::is_open)
137 | .def("num_records", &ArrayRecordReaderBase::NumRecords)
138 | .def("record_index", &ArrayRecordReaderBase::RecordIndex)
139 | .def("writer_options_string", &ArrayRecordReaderBase::WriterOptionsString)
140 | .def("seek",
141 | [](ArrayRecordReaderBase& reader, int64_t record_index) {
142 | if (!reader.SeekRecord(record_index)) {
143 | throw std::runtime_error(std::string(reader.status().message()));
144 | }
145 | })
146 | // See write() for why this returns py::bytes.
147 | .def("read",
148 | [](ArrayRecordReaderBase& reader) {
149 | absl::string_view string_view;
150 | if (!reader.ReadRecord(&string_view)) {
151 | if (reader.ok()) {
152 | throw std::out_of_range(absl::StrFormat(
153 | "Out of range of num_records: %d", reader.NumRecords()));
154 | }
155 | throw std::runtime_error(std::string(reader.status().message()));
156 | }
157 | return py::bytes(string_view);
158 | })
159 | .def("read",
160 | [](ArrayRecordReaderBase& reader, std::vector<uint64_t> indices) {
161 | std::vector<std::string> staging(indices.size());
162 | py::list output(indices.size());
163 | {
164 | py::gil_scoped_release scoped_release;
165 | auto status = reader.ParallelReadRecordsWithIndices(
166 | indices,
167 | [&](uint64_t indices_index,
168 | absl::string_view record_data) -> absl::Status {
169 | staging[indices_index] = record_data;
170 | return absl::OkStatus();
171 | });
172 | if (!status.ok()) {
173 | throw std::runtime_error(std::string(status.message()));
174 | }
175 | }
176 | // TODO(fchern): Can we write the data directly to the output
177 | // list in our Parallel loop?
178 | ssize_t index = 0;
179 | for (const auto& record : staging) {
180 | auto py_record = py::bytes(record);
181 | PyList_SET_ITEM(output.ptr(), index++,
182 | py_record.release().ptr());
183 | }
184 | return output;
185 | })
186 | .def("read",
187 | [](ArrayRecordReaderBase& reader, int32_t begin, int32_t end) {
188 | int32_t range_begin = begin, range_end = end;
189 | if (range_begin < 0) {
190 | range_begin = reader.NumRecords() + range_begin;
191 | }
192 | if (range_end < 0) {
193 | range_end = reader.NumRecords() + range_end;
194 | }
195 | if (range_begin > reader.NumRecords() || range_begin < 0 ||
196 | range_end > reader.NumRecords() || range_end < 0 ||
197 | range_end <= range_begin) {
198 | throw std::out_of_range(
199 | absl::StrFormat("[%d, %d) is out of range of [0, %d)", begin,
200 | end, reader.NumRecords()));
201 | }
202 | int32_t num_to_read = range_end - range_begin;
203 | std::vector<std::string> staging(num_to_read);
204 | py::list output(num_to_read);
205 | {
206 | py::gil_scoped_release scoped_release;
207 | auto status = reader.ParallelReadRecordsInRange(
208 | range_begin, range_end,
209 | [&](uint64_t index,
210 | absl::string_view record_data) -> absl::Status {
211 | staging[index - range_begin] = record_data;
212 | return absl::OkStatus();
213 | });
214 | if (!status.ok()) {
215 | throw std::runtime_error(std::string(status.message()));
216 | }
217 | }
218 | // TODO(fchern): Can we write the data directly to the output
219 | // list in our Parallel loop?
220 | ssize_t index = 0;
221 | for (const auto& record : staging) {
222 | auto py_record = py::bytes(record);
223 | PyList_SET_ITEM(output.ptr(), index++,
224 | py_record.release().ptr());
225 | }
226 | return output;
227 | })
228 | .def("read_all", [](ArrayRecordReaderBase& reader) {
229 | std::vector<std::string> staging(reader.NumRecords());
230 | py::list output(reader.NumRecords());
231 | {
232 | py::gil_scoped_release scoped_release;
233 | auto status = reader.ParallelReadRecords(
234 | [&](uint64_t index,
235 | absl::string_view record_data) -> absl::Status {
236 | staging[index] = record_data;
237 | return absl::OkStatus();
238 | });
239 | if (!status.ok()) {
240 | throw std::runtime_error(std::string(status.message()));
241 | }
242 | }
243 | // TODO(fchern): Can we write the data directly to the output
244 | // list in our Parallel loop?
245 | ssize_t index = 0;
246 | for (const auto& record : staging) {
247 | auto py_record = py::bytes(record);
248 | PyList_SET_ITEM(output.ptr(), index++, py_record.release().ptr());
249 | }
250 | return output;
251 | });
252 | }
253 |
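254 | // A short Python sketch of the bindings above (illustrative path; the
255 | // installed wheel exposes this module as
256 | // array_record.python.array_record_module):
257 | //
258 | //   from array_record.python.array_record_module import ArrayRecordReader
259 | //   from array_record.python.array_record_module import ArrayRecordWriter
260 | //
261 | //   writer = ArrayRecordWriter("/tmp/demo.array_record", "group_size:1")
262 | //   writer.write(b"hello")  # records are bytes, not str
263 | //   writer.close()
264 | //
265 | //   # Options tuned for random access, per the reader docstring above.
266 | //   reader = ArrayRecordReader(
267 | //       "/tmp/demo.array_record", "readahead_buffer_size:0,max_parallelism:0")
268 | //   assert reader.read([0]) == [b"hello"]
269 | //   reader.close()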
--------------------------------------------------------------------------------
/python/array_record_module_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for array_record_module."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 |
21 | from python.array_record_module import ArrayRecordReader
22 | from python.array_record_module import ArrayRecordWriter
23 |
24 |
25 | class ArrayRecordModuleTest(absltest.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.test_file = os.path.join(self.create_tempdir().full_path,
30 | "test.arecord")
31 |
32 | def test_open_and_close(self):
33 | writer = ArrayRecordWriter(self.test_file)
34 | self.assertTrue(writer.ok())
35 | self.assertTrue(writer.is_open())
36 | writer.close()
37 | self.assertFalse(writer.is_open())
38 |
39 | reader = ArrayRecordReader(self.test_file)
40 | self.assertTrue(reader.ok())
41 | self.assertTrue(reader.is_open())
42 | reader.close()
43 | self.assertFalse(reader.is_open())
44 |
45 | def test_bad_options(self):
46 |
47 | def create_writer():
48 | ArrayRecordWriter(self.test_file, "blah")
49 |
50 | def create_reader():
51 | ArrayRecordReader(self.test_file, "blah")
52 |
53 | self.assertRaises(ValueError, create_writer)
54 | self.assertRaises(ValueError, create_reader)
55 |
56 | def test_write_read(self):
57 | writer = ArrayRecordWriter(self.test_file)
58 | test_strs = [b"abc", b"def", b"ghi"]
59 | for s in test_strs:
60 | writer.write(s)
61 | writer.close()
62 | reader = ArrayRecordReader(
63 | self.test_file, "readahead_buffer_size:0,max_parallelism:0"
64 | )
65 | num_strs = len(test_strs)
66 | self.assertEqual(reader.num_records(), num_strs)
67 | self.assertEqual(reader.record_index(), 0)
68 | for gt in test_strs:
69 | result = reader.read()
70 | self.assertEqual(result, gt)
71 | self.assertRaises(IndexError, reader.read)
72 | reader.seek(0)
73 | self.assertEqual(reader.record_index(), 0)
74 | self.assertEqual(reader.read(), test_strs[0])
75 | self.assertEqual(reader.record_index(), 1)
76 |
77 | def test_write_read_non_unicode(self):
78 | writer = ArrayRecordWriter(self.test_file)
79 | b = b"F\xc3\xb8\xc3\xb6\x97\xc3\xa5r"
80 | writer.write(b)
81 | writer.close()
82 | reader = ArrayRecordReader(self.test_file)
83 | self.assertEqual(reader.read(), b)
84 |
85 | def test_write_read_with_file_reader_buffer_size(self):
86 | writer = ArrayRecordWriter(self.test_file)
87 | b = b"F\xc3\xb8\xc3\xb6\x97\xc3\xa5r"
88 | writer.write(b)
89 | writer.close()
90 | reader = ArrayRecordReader(self.test_file, file_reader_buffer_size=2**10)
91 | self.assertEqual(reader.read(), b)
92 |
93 | def test_batch_read(self):
94 | writer = ArrayRecordWriter(self.test_file)
95 | test_strs = [b"abc", b"def", b"ghi", b"kkk", b"..."]
96 | for s in test_strs:
97 | writer.write(s)
98 | writer.close()
99 | reader = ArrayRecordReader(self.test_file)
100 | results = reader.read_all()
101 | self.assertEqual(test_strs, results)
102 | indices = [1, 3, 0]
103 | expected = [test_strs[i] for i in indices]
104 | batch_fetch = reader.read(indices)
105 | self.assertEqual(expected, batch_fetch)
106 |
107 | def test_read_range(self):
108 | writer = ArrayRecordWriter(self.test_file)
109 | test_strs = [b"abc", b"def", b"ghi", b"kkk", b"..."]
110 | for s in test_strs:
111 | writer.write(s)
112 | writer.close()
113 | reader = ArrayRecordReader(self.test_file)
114 |
115 | def invalid_range1():
116 | reader.read(0, 0)
117 |
118 | self.assertRaises(IndexError, invalid_range1)
119 |
120 | def invalid_range2():
121 | reader.read(0, 100)
122 |
123 | self.assertRaises(IndexError, invalid_range2)
124 |
125 | def invalid_range3():
126 | reader.read(3, 2)
127 |
128 | self.assertRaises(IndexError, invalid_range3)
129 |
130 | self.assertEqual(reader.read(0, -1), test_strs[0:-1])
131 | self.assertEqual(reader.read(-3, -1), test_strs[-3:-1])
132 | self.assertEqual(reader.read(1, 3), test_strs[1:3])
133 |
134 | def test_writer_options(self):
135 | writer = ArrayRecordWriter(self.test_file, "group_size:42")
136 | writer.write(b"test123")
137 | writer.close()
138 | reader = ArrayRecordReader(self.test_file)
139 | # Includes default options.
140 | self.assertEqual(
141 | reader.writer_options_string(),
142 | "group_size:42,transpose:false,pad_to_block_boundary:false,zstd:3,"
143 | "window_log:20,max_parallelism:1")
144 |
145 |
146 | if __name__ == "__main__":
147 | absltest.main()
148 |
--------------------------------------------------------------------------------
/python/testdata/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 |
3 | licenses(["notice"])
4 |
5 | exports_files(glob(["*.array_record-*"]))
6 |
--------------------------------------------------------------------------------
/python/testdata/digits.array_record-00000-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/testdata/digits.array_record-00000-of-00002
--------------------------------------------------------------------------------
/python/testdata/digits.array_record-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/testdata/digits.array_record-00001-of-00002
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ArrayRecord Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This is the list of our Python third_party dependencies that Bazel should
16 | # pull from PyPi.
17 | # Note that requirements.txt must be re-generated using
18 | # bazel run //:requirements.update in the OSS version.
19 |
20 | etils[epath]
21 |
--------------------------------------------------------------------------------
/requirements_lock.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.10
3 | # by the following command:
4 | #
5 | # bazel run //:requirements.update
6 | #
7 | etils[epath,epy]==1.11.0 \
8 | --hash=sha256:a394cf3476bcec51c221426a70c39cd1006e889456ba41e4d7f12fd6814be7a5 \
9 | --hash=sha256:aff3278a3be7fddf302dfd80335e9f924244666c71239cd91e836f3d055f1c4a
10 | # via -r requirements.in
11 | fsspec==2024.12.0 \
12 | --hash=sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f \
13 | --hash=sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2
14 | # via etils
15 | importlib-resources==6.4.5 \
16 | --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \
17 | --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717
18 | # via etils
19 | typing-extensions==4.12.2 \
20 | --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
21 | --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
22 | # via etils
23 | zipp==3.21.0 \
24 | --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \
25 | --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931
26 | # via etils
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup.py file for array_record."""
2 |
3 | from setuptools import find_packages
4 | from setuptools import setup
5 | from setuptools.dist import Distribution
6 |
7 | REQUIRED_PACKAGES = [
8 | 'absl-py',
9 | 'etils[epath]',
10 | ]
11 |
12 | BEAM_EXTRAS = [
13 | 'apache-beam[gcp]==2.53.0',
14 | 'google-cloud-storage>=2.11.0',
15 | 'tensorflow>=2.14.0'
16 | ]
17 |
18 |
19 | class BinaryDistribution(Distribution):
20 | """This class makes 'bdist_wheel' include an ABI tag on the wheel."""
21 |
22 | def has_ext_modules(self):
23 | return True
24 |
25 |
26 | setup(
27 | name='array_record',
28 | version='0.6.0',
29 | description='A file format that achieves a new frontier of IO efficiency',
30 | author='ArrayRecord team',
31 | author_email='no-reply@google.com',
32 | packages=find_packages(),
33 | include_package_data=True,
34 | package_data={'': ['*.so']},
35 | python_requires='>=3.10',
36 | install_requires=REQUIRED_PACKAGES,
37 | extras_require={'beam': BEAM_EXTRAS},
38 | url='https://github.com/google/array_record',
39 | license='Apache-2.0',
40 | classifiers=[
41 | 'Programming Language :: Python :: 3.10',
42 | 'Programming Language :: Python :: 3.11',
43 | 'Programming Language :: Python :: 3.12',
44 | ],
45 | zip_safe=False,
46 | distclass=BinaryDistribution,
47 | )
48 |
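49 | # The Beam pipeline dependencies declared above are optional; they can be
50 | # pulled in via the extra (illustrative):
51 | #
52 | #   pip install 'array_record[beam]'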
--------------------------------------------------------------------------------