├── .bazelrc
├── .github
│   └── workflows
│       └── python-tests.yml
├── BUILD
├── CONTRIBUTING.md
├── LICENSE
├── MODULE.bazel
├── README.md
├── __init__.py
├── beam
│   ├── README.md
│   ├── __init__.py
│   ├── arrayrecordio.py
│   ├── demo.py
│   ├── dofns.py
│   ├── example.py
│   ├── examples
│   │   ├── example_full_demo_cli.sh
│   │   ├── example_gcs_conversion.py
│   │   ├── example_sink_conversion.py
│   │   └── requirements.txt
│   ├── options.py
│   ├── pipelines.py
│   └── testdata.py
├── cpp
│   ├── BUILD
│   ├── array_record_reader.cc
│   ├── array_record_reader.h
│   ├── array_record_reader_test.cc
│   ├── array_record_writer.cc
│   ├── array_record_writer.h
│   ├── array_record_writer_test.cc
│   ├── common.h
│   ├── layout.proto
│   ├── masked_reader.cc
│   ├── masked_reader.h
│   ├── masked_reader_test.cc
│   ├── parallel_for.h
│   ├── parallel_for_test.cc
│   ├── sequenced_chunk_writer.cc
│   ├── sequenced_chunk_writer.h
│   ├── sequenced_chunk_writer_test.cc
│   ├── test_utils.cc
│   ├── test_utils.h
│   ├── test_utils_test.cc
│   ├── thread_pool.cc
│   ├── thread_pool.h
│   ├── tri_state_ptr.h
│   └── tri_state_ptr_test.cc
├── oss
│   ├── README.md
│   ├── build.Dockerfile
│   ├── build_whl.sh
│   └── runner_common.sh
├── python
│   ├── BUILD
│   ├── __init__.py
│   ├── array_record_data_source.py
│   ├── array_record_data_source_test.py
│   ├── array_record_module.cc
│   ├── array_record_module_test.py
│   └── testdata
│       ├── BUILD
│       ├── digits.array_record-00000-of-00002
│       └── digits.array_record-00001-of-00002
├── requirements.in
├── requirements_lock.txt
└── setup.py

/.bazelrc:
--------------------------------------------------------------------------------
 1 | build -c opt
 2 | build --cxxopt=-std=c++17
 3 | build --host_cxxopt=-std=c++17
 4 | build --experimental_repo_remote_exec
 5 | 
 6 | # TODO(fchern): Use non-hardcoded paths.
 7 | build --action_env=PYTHON_BIN_PATH="/usr/bin/python3"
 8 | build --action_env=PYTHON_LIB_PATH="/usr/lib/python3"
 9 | build --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
10 | build --python_path="/usr/bin/python3"
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
 1 | name: Build and test
 2 | 
 3 | on:
 4 |   pull_request:
 5 |     branches: [main]
 6 | 
 7 | jobs:
 8 |   test:
 9 |     runs-on: ubuntu-latest
10 |     strategy:
11 |       matrix:
12 |         python-version: ['3.9', '3.10', '3.11']
13 |     env:
14 |       DOCKER_BUILDKIT: 1
15 |       TMP_FOLDER: /tmp/array_record
16 |     steps:
17 |     - uses: actions/checkout@v2
18 |     - name: Build Docker image
19 |       run: |
20 |         docker build --progress=plain --no-cache \
21 |           --build-arg PYTHON_VERSION=${{ matrix.python-version }} \
22 |           -t array_record:latest - < oss/build.Dockerfile
23 |     - name: Build wheels and test
24 |       run: |
25 |         docker run --rm -a stdin -a stdout -a stderr \
26 |           --env PYTHON_VERSION=${{ matrix.python-version }} \
27 |           --volume ${GITHUB_WORKSPACE}:${TMP_FOLDER} --name array_record array_record:latest \
28 |           bash oss/build_whl.sh
29 |     - name: Install in a blank Docker and test the import in Python
30 |       run: |
31 |         docker run --rm -a stdin -a stdout -a stderr \
32 |           --env PYTHON_VERSION=${{ matrix.python-version }} \
33 |           --volume ${GITHUB_WORKSPACE}:/root \
34 |           python:${{ matrix.python-version }} bash -c "
35 |             ARRAY_RECORD_VERSION=\$(python /root/setup.py --version 2>/dev/null)
36 |             SHORT_PYTHON_VERSION=\${PYTHON_VERSION//./}
37 |             ARRAY_RECORD_WHEEL=\"/root/all_dist/array_record-\${ARRAY_RECORD_VERSION}-py\${SHORT_PYTHON_VERSION}-none-any.whl\"
38 |             python -m pip install \${ARRAY_RECORD_WHEEL} &&
39 |             python -c 'import array_record' &&
40 |             python -c 'from array_record.python import array_record_data_source'
41 |           "
42 | 
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
 1 | # ArrayRecord is a new file format for IO intensive applications.
 2 | # It supports efficient random access and various compression algorithms.
 3 | 
 4 | load("@rules_python//python:pip.bzl", "compile_pip_requirements")
 5 | 
 6 | 
 7 | package(default_visibility = ["//visibility:public"])
 8 | 
 9 | licenses(["notice"])
10 | 
11 | exports_files(["LICENSE"])
12 | 
13 | py_library(
14 |     name = "setup",
15 |     srcs = ["setup.py"],
16 | )
17 | 
18 | compile_pip_requirements(
19 |     name = "requirements",
20 |     requirements_in = "requirements.in",
21 |     requirements_txt = "requirements_lock.txt",
22 | )
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | We'd love to accept your patches and contributions to this project. There are
 4 | just a few small guidelines you need to follow.
 5 | 
 6 | ## Contributor License Agreement
 7 | 
 8 | Contributions to this project must be accompanied by a Contributor License
 9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code Reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 
26 | ## Community Guidelines
27 | 
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 |                                  Apache License
 2 |                            Version 2.0, January 2004
 3 |                         http://www.apache.org/licenses/
 4 | 
 5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 6 | 
 7 |    1. Definitions.
 8 | 
 9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MODULE.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The ArrayRecord Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # TODO(fchern): automate version string alignment with setup.py 16 | VERSION = "0.6.0" 17 | 18 | module( 19 | name = "array_record", 20 | version = VERSION, 21 | repo_name = "com_google_array_record", 22 | ) 23 | 24 | bazel_dep(name = "rules_proto", version = "7.0.2") 25 | bazel_dep(name = "rules_python", version = "0.40.0") 26 | bazel_dep(name = "platforms", version = "0.0.10") 27 | bazel_dep(name = "protobuf", version = "29.0") 28 | bazel_dep(name = "googletest", version = "1.15.2") 29 | bazel_dep(name = "abseil-cpp", version = "20240722.0") 30 | bazel_dep(name = "abseil-py", version = "2.1.0") 31 | bazel_dep(name = "eigen", version = "3.4.0.bcr.2") 32 | bazel_dep(name = "riegeli", version = "0.0.0-20241218-3385e3c") 33 | bazel_dep(name = "pybind11_bazel", version = "2.12.0") 34 | 35 | PYTHON_VERSION = "3.10" 36 | 37 | python = use_extension("@rules_python//python/extensions:python.bzl", "python") 38 | python.toolchain( 39 | ignore_root_user_error = True, # Required for our containerized CI environments. 40 | python_version = PYTHON_VERSION, 41 | ) 42 | 43 | pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") 44 | 45 | # requirements_lock.txt is generated by 46 | # bazel run //:requirements.update 47 | pip.parse( 48 | hub_name = "pypi", 49 | python_version = PYTHON_VERSION, 50 | requirements_lock = "//:requirements_lock.txt", 51 | ) 52 | use_repo(pip, "pypi") 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ArrayRecord 2 | 3 | ArrayRecord is a new file format derived from 4 | [Riegeli](https://github.com/google/riegeli), achieving a new 5 | frontier of IO efficiency. We designed ArrayRecord to support parallel read, 6 | write, and random access by record index. 7 | ArrayRecord builds on top of Riegeli and supports the same compression 8 | algorithms. 
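
A minimal usage sketch of the Python API is shown below. The writer class comes from `array_record.python.array_record_module` (the same import used by the Beam integration in this repo); the data source comes from `array_record.python.array_record_data_source`. The file path is a placeholder, and the `len()`/index-based access on the data source is an assumption based on its role as a random-access data source:

```python
from array_record.python.array_record_module import ArrayRecordWriter
from array_record.python.array_record_data_source import ArrayRecordDataSource

# Write a few records; records are bytes.
writer = ArrayRecordWriter('/tmp/demo.array_record', 'group_size:1')
for i in range(3):
    writer.write(f'record {i}'.encode('utf-8'))
writer.close()  # close() is required to finalize the file

# Random access by record index.
ds = ArrayRecordDataSource(['/tmp/demo.array_record'])
print(len(ds))  # number of records
print(ds[1])    # the second record, as bytes
```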
9 | 
10 | 
11 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/__init__.py
--------------------------------------------------------------------------------
/beam/README.md:
--------------------------------------------------------------------------------
 1 | ## Apache Beam Integration for ArrayRecord
 2 | 
 3 | ### Quickstart
 4 | 
 5 | #### Convert TFRecord in a GCS bucket to ArrayRecord
 6 | ```
 7 | pip install apache-beam[gcp]==2.53.0
 8 | pip install array-record[beam]
 9 | # check that apache-beam is still at 2.53.0
10 | pip show apache-beam
11 | git clone https://github.com/google/array_record.git
12 | cd array_record/beam/examples
13 | # Fill in the required fields in example_gcs_conversion.py
14 | # If using Dataflow, set pipeline_options as instructed in example_gcs_conversion.py
15 | python example_gcs_conversion.py
16 | ```
17 | If Dataflow is used, you can monitor the run from the Dataflow job monitoring UI (https://cloud.google.com/dataflow/docs/guides/monitoring-overview).
18 | 
19 | ### Summary
20 | 
21 | This submodule provides some Apache Beam components and lightweight pipelines for converting different file formats (TFRecord at present) into ArrayRecords. The intention is to provide a variety of fairly seamless tools for migrating existing TFRecord datasets, allowing a few different choices regarding sharding and write location.
22 | 
23 | There are two core components of this module:
24 | 
25 | 1. A Beam `PTransform` with a `FileBasedSink` for writing ArrayRecords. It's modeled after similar components like `TFRecordIO` and `FileIO`. Worth noting in this implementation is that `array_record`'s `ArrayRecordWriter` object requires a file-path-like string to initialize, and the `.close()` method is required to make the file usable. This characteristic forces the overriding of Beam's default `.open()` functionality, which is where its scheme and file handling functionality is housed. In short, it means this sink is **only usable for ArrayRecord writes to disk or disk-like paths, e.g. FUSE, NFS mounts, etc.** All writes using scheme prefixes (e.g. `gs://`) will fail.
26 | 
27 | 2. A Beam `DoFn` that accepts a single tuple consisting of a filename key and an entire set of serialized records. The function writes the serialized content to an ArrayRecord file in an on-disk path, uploads it to a specified GCS bucket, and removes the temporary file. This function has no inherent file awareness, making its primary goal the writing of a single file per PCollection. As such, it requires the file content division logic to be provided to the function elsewhere in the Beam pipeline.
28 | 
29 | In addition to these components, there are a number of simple pipelines included in this module that provide basic likely implementations of the above components. A few of those pipelines are as follows:
30 | 
31 | 1. **Conversion from a set number of TFRecord files in either GCS or on-disk to a flexible number of ArrayRecords on disk:** Leverages the PTransform/Sink, and due to Beam's file handling capabilities allows for a `num_shards` argument that supports redistribution of the bounded dataset across an arbitrary number of files. However, due to overriding the `open()` method, writes to GCS don't work.
32 | 
33 | 2. **Conversion from a set number of TFRecord files in either GCS or on-disk to a matching number of ArrayRecords on GCS:** Leverages the `ReadAllFromTFRecord` and `GroupByKey` Beam functions to organize a set of filename:content pairs, which are then passed to the ArrayRecord `DoFn`. The end result is that TFRecords are converted to ArrayRecords one-to-one.
34 | 
35 | 3. **Conversion from a set number of TFRecord files in either GCS or on-disk to a matching number of ArrayRecords on disk:** Identical to pipeline 1, it just reads the number of shards first and sets the number of ArrayRecord shards to match.
36 | 
37 | In addition to all of that, there are a handful of dummy data generation functions used for testing and validation.
38 | 
39 | ### Usage
40 | 
41 | **Basics and 'Getting Started'**
42 | 
43 | Please note that in an attempt to keep the array_record library lightweight, Apache Beam (and some of the underlying data generation dependencies like TensorFlow) are not installed by default when you run `pip install array-record`. To get the extra packages automatically, run `pip install array-record[beam]`.
44 | 
45 | Once installed, all of the Beam components are available to import from `array_record.beam`.
46 | 
47 | **Importing the PTransform or the DoFn**
48 | 
49 | If you're familiar with Apache Beam and want to build a custom pipeline around its core constructs, you can import the native Beam objects and implement them as you see fit.
50 | 
51 | To import the PTransform with the disk-based sink, use `from array_record.beam.arrayrecordio import WriteToArrayRecord`. You may then use it as a standard step in a Beam pipeline. It accepts a variety of different inputs including `file_path_prefix`, `file_name_suffix`, `coder`, and `num_shards`. For more detail, as well as options for extensibility, please refer to [Apache Beam's documentation for FileBasedSink](https://beam.apache.org/releases/pydoc/current/apache_beam.io.filebasedsink.html).
52 | 
53 | 
54 | To import the custom DoFn, use `from array_record.beam.dofns import ConvertToArrayRecordGCS`. You may then use it as a parameter for a Beam `ParDo`. It takes a handful of side inputs as described below:
55 | 
56 | - **path:** REQUIRED (and positional). The intended path prefix for the GCS bucket in "gs://..." format
57 | - **overwrite_extension:** FALSE by default. Boolean making the DoFn attempt to overwrite any file extension after "."
58 | - **file_path_suffix:** ".arrayrecord" by default. Intended suffix for overwrite or append
59 | 
60 | Note that by default, the DoFn will APPEND an existing filename/extension with ".arrayrecord". Setting `file_path_suffix` to `""` will leave the file names as-is and thus expect you to be passing in a different `path` than the source.
61 | 
62 | You can see usage details for each of these implementations in `pipelines.py`.
63 | 
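For orientation, here is a minimal sketch of the disk-based sink inside a pipeline, based on the signatures in `arrayrecordio.py`. The records and output prefix are placeholders, and per the caveat above the prefix must be a disk or disk-like path:

```python
import apache_beam as beam

from array_record.beam.arrayrecordio import WriteToArrayRecord

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([b'first record', b'second record'])  # already-serialized bytes
        | WriteToArrayRecord(
            '/tmp/records/movies',  # placeholder path; "gs://" prefixes will fail
            file_name_suffix='.arrayrecord',
            num_shards=1,
        )
    )
```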
64 | **Using the Helper Functions**
65 | 
66 | Several helper functions have been packaged to make the functionality more accessible to those with less comfort building Apache Beam pipelines. All of these pipelines take `input` and `output` arguments, which are intended as the respective source and destination paths of the TFRecord files and the ArrayRecord files. Wildcards are accepted in these paths. By default, these parameters can either be passed as CLI arguments when executing a pipeline as `python -m <pipeline_module> --input <input_pattern> --output <output_path>`, or as an override to the `args` argument if executing programmatically. Additionally, extra arguments can be passed via CLI or programmatically in the `pipeline_options` argument if you want to control the behavior of Beam. The likely reason for this would be altering the Runner to Google Cloud Dataflow, which these examples support (with caveats; see the section below on Dataflow).
67 | 
68 | There are slight variations in execution when running these either from an interpreter or the CLI, so familiarize yourself with the files in the `examples/` directory along with `demo.py`, which show the different invocation methods. The available functions can all be imported `from array_record.beam.pipelines import *` and are as follows:
69 | 
70 | - **convert_tf_to_arrayrecord_disk:** Converts TFRecords at `input` path to ArrayRecords at `output` path for disk-based writes only. Accepts an extra `num_shards` argument for resharding ArrayRecords across an arbitrary number of files.
71 | - **convert_tf_to_arrayrecord_disk_match_shards:** Same as above, except it reads the number of source files and matches them to the destination. There is no `num_shards` argument.
72 | - **convert_tf_to_arrayrecord_gcs:** Converts TFRecords at `input` path to ArrayRecords at `output` path, where the `output` path **must** be a GCS bucket in "gs://" format. This function accepts the same `overwrite_extension` and `file_path_suffix` arguments as the DoFn itself, allowing for customization of file naming.
73 | 
74 | ### Examples and Demos
75 | 
76 | See the examples in the `examples/` directory for different invocation techniques. One of the examples invokes `array_record.beam.demo` as a module, which is a simple pipeline that generates some TFRecords and then converts them to ArrayRecord in GCS. You can see the implementation in `demo.py`, which should serve as a guide for implementing your own CLI-triggered pipelines.
77 | 
78 | You'll also note commented sections in each example, which are the configuration parameters for running the pipelines on Google Cloud Dataflow. There is also a `requirements.txt` in there, which at present is required for running these on Dataflow as is. See below for more detail.
79 | 
80 | ### Dataflow Usage
81 | 
82 | These pipelines have all been tested and are compatible with Google Cloud Dataflow. Uncomment the sections in the example files and set your own bucket/project information to see it in action.
83 | 
84 | Note, however, the `requirements.txt` file. This is necessary because the `array-record` PyPI installation does not install the Apache Beam or TensorFlow components by default to keep the library lightweight. A `requirements.txt` passed as an argument to the Dataflow job is required to ensure everything is installed correctly on the runner.
85 | 
86 | 
87 | Allow to simmer uncovered for 5 minutes. Plate, serve, and enjoy.
--------------------------------------------------------------------------------
/beam/__init__.py:
--------------------------------------------------------------------------------
 1 | """Apache Beam module for array_record.
 2 | 
 3 | This module provides both core components and
 4 | helper functions to enable users to convert different file formats to AR.
 5 | 
 6 | To keep dependencies light, we'll import Beam on module usage so any errors
 7 | occur early.
 8 | """
 9 | 
10 | import apache_beam as beam
11 | 
12 | # I'd really like a PEP8 compatible conditional import here with a more
13 | # explicit error message. Example below:
14 | 
15 | # try:
16 | #   import apache_beam as beam
17 | # except Exception as e:
18 | #   raise ImportError(
19 | #       ('Beam functionality requires extra dependencies. '
20 | #        'Install apache-beam or run "pip install array_record[beam]".')) from e
--------------------------------------------------------------------------------
/beam/arrayrecordio.py:
--------------------------------------------------------------------------------
 1 | """An IO module for ArrayRecord.
 2 | 
 3 | CURRENTLY ONLY SINK IS IMPLEMENTED, AND IT DOESN'T WORK WITH NON-DISK WRITES
 4 | """
 5 | 
 6 | from apache_beam import io
 7 | from apache_beam import transforms
 8 | from apache_beam.coders import coders
 9 | from apache_beam.io import filebasedsink
10 | from apache_beam.io import filesystem
11 | from array_record.python import array_record_module
12 | 
13 | 
14 | class _ArrayRecordSink(filebasedsink.FileBasedSink):
15 |   """Sink class for use in the ArrayRecord PTransform."""
16 | 
17 |   def __init__(
18 |       self,
19 |       file_path_prefix,
20 |       file_name_suffix='',
21 |       num_shards=0,
22 |       shard_name_template=None,
23 |       coder=coders.ToBytesCoder(),
24 |       compression_type=filesystem.CompressionTypes.AUTO):
25 | 
26 |     super().__init__(
27 |         file_path_prefix,
28 |         file_name_suffix=file_name_suffix,
29 |         num_shards=num_shards,
30 |         shard_name_template=shard_name_template,
31 |         coder=coder,
32 |         mime_type='application/octet-stream',
33 |         compression_type=compression_type)
34 | 
35 |   def open(self, temp_path):
36 |     array_writer = array_record_module.ArrayRecordWriter(
37 |         temp_path, 'group_size:1'
38 |     )
39 |     return array_writer
40 | 
41 |   def close(self, file_handle):
42 |     file_handle.close()
43 | 
44 |   def write_encoded_record(self, file_handle, value):
45 |     file_handle.write(value)
46 | 
47 | 
48 | class WriteToArrayRecord(transforms.PTransform):
49 |   """PTransform for a disk-based write to ArrayRecord."""
50 | 
51 |   def __init__(
52 |       self,
53 |       file_path_prefix,
54 |       file_name_suffix='',
55 |       num_shards=0,
56 |       shard_name_template=None,
57 |       coder=coders.ToBytesCoder(),
58 |       compression_type=filesystem.CompressionTypes.AUTO):
59 | 
60 |     self._sink = _ArrayRecordSink(
61 |         file_path_prefix,
62 |         file_name_suffix,
63 |         num_shards,
64 |         shard_name_template,
65 |         coder,
66 |         compression_type)
67 | 
68 |   def expand(self, pcoll):
69 |     return pcoll | io.iobase.Write(self._sink)
--------------------------------------------------------------------------------
/beam/demo.py:
--------------------------------------------------------------------------------
 1 | """Demo Pipeline.
 2 | 
 3 | This file creates a TFRecord dataset and converts it to ArrayRecord on GCS.
 4 | """
 5 | 
 6 | import apache_beam as beam
 7 | from apache_beam.coders import coders
 8 | from . import dofns
 9 | from . import example
10 | from . import options
11 | 
12 | 
13 | ## Grab CLI arguments.
14 | ## Override by passing args/pipeline_options to the function manually.
15 | args, pipeline_options = options.get_arguments()
16 | 
17 | 
18 | def main():
19 |   p1 = beam.Pipeline(options=pipeline_options)
20 |   initial = (p1
21 |              | 'Create a set of TFExamples' >> beam.Create(
22 |                  example.generate_movie_examples()
23 |              )
24 |              | 'Write TFRecords' >> beam.io.WriteToTFRecord(
25 |                  args['input'],
26 |                  coder=coders.ToBytesCoder(),
27 |                  num_shards=4,
28 |                  file_name_suffix='.tfrecord'
29 |              )
30 |              | 'Read shards from GCS' >> beam.io.ReadAllFromTFRecord(
31 |                  with_filename=True)
32 |              | 'Group with Filename' >> beam.GroupByKey()
33 |              | 'Write to ArrayRecord in GCS' >> beam.ParDo(
34 |                  dofns.ConvertToArrayRecordGCS(),
35 |                  args['output'],
36 |                  overwrite_extension=True))
37 | 
38 |   return p1, initial
39 | 
40 | 
41 | if __name__ == '__main__':
42 |   demo_pipeline, _ = main()
43 |   demo_pipeline.run()
44 | 
--------------------------------------------------------------------------------
/beam/dofns.py:
--------------------------------------------------------------------------------
 1 | """DoFns for parallel processing."""
 2 | 
 3 | import os
 4 | import urllib.parse
 5 | import apache_beam as beam
 6 | from array_record.python.array_record_module import ArrayRecordWriter
 7 | from google.cloud import storage
 8 | 
 9 | 
10 | class ConvertToArrayRecordGCS(beam.DoFn):
11 |   """Write a tuple consisting of a filename and records to GCS ArrayRecords."""
12 | 
13 |   _WRITE_DIR = '/tmp/'
14 | 
15 |   def process(
16 |       self,
17 |       element,
18 |       path,
19 |       write_dir=_WRITE_DIR,
20 |       file_path_suffix='.arrayrecord',
21 |       overwrite_extension=False,
22 |   ):
23 | 
24 |     ## Upload to GCS
25 |     def upload_to_gcs(bucket_name, filename, prefix='', source_dir=write_dir):
26 |       source_filename = os.path.join(source_dir, filename)
27 |       blob_name = os.path.join(prefix, filename)
28 |       storage_client = storage.Client()
29 |       bucket = storage_client.get_bucket(bucket_name)
30 |       blob = bucket.blob(blob_name)
31 |       blob.upload_from_filename(source_filename)
32 | 
33 |     ## Simple logic for stripping a file extension and replacing it
34 |     def fix_filename(filename):
35 |       base_name = os.path.splitext(filename)[0]
36 |       new_filename = base_name + file_path_suffix
37 |       return new_filename
38 | 
39 |     parsed_gcs_path = urllib.parse.urlparse(path)
40 |     bucket_name = parsed_gcs_path.hostname
41 |     gcs_prefix = parsed_gcs_path.path.lstrip('/')
42 | 
43 |     if overwrite_extension:
44 |       filename = fix_filename(os.path.basename(element[0]))
45 |     else:
46 |       filename = '{}{}'.format(os.path.basename(element[0]), file_path_suffix)
47 | 
48 |     write_path = os.path.join(write_dir, filename)
49 |     writer = ArrayRecordWriter(write_path, 'group_size:1')
50 | 
51 |     for item in element[1]:
52 |       writer.write(item)
53 | 
54 |     writer.close()
55 | 
56 |     upload_to_gcs(bucket_name, filename, prefix=gcs_prefix)
57 |     os.remove(os.path.join(write_dir, filename))
--------------------------------------------------------------------------------
/beam/example.py:
--------------------------------------------------------------------------------
 1 | """Helper file for generating TF/ArrayRecords and writing them to disk."""
 2 | 
 3 | import os
 4 | from array_record.python.array_record_module import ArrayRecordWriter
 5 | import tensorflow as tf
 6 | from . import testdata
 7 | 
 8 | 
 9 | def generate_movie_examples():
10 |   """Create a list of TF examples from the dummy data in testdata and return it.
 11 | 
 12 |   Returns:
 13 |     A list of tf.train.Example objects.
 14 |   """
 15 | 
 16 |   examples = []
 17 |   for example in testdata.data:
 18 |     examples.append(
 19 |         tf.train.Example(
 20 |             features=tf.train.Features(
 21 |                 feature={
 22 |                     'Age': tf.train.Feature(
 23 |                         int64_list=tf.train.Int64List(value=[example['Age']])),
 24 |                     'Movie': tf.train.Feature(
 25 |                         bytes_list=tf.train.BytesList(
 26 |                             value=[
 27 |                                 m.encode('utf-8') for m in example['Movie']])),
 28 |                     'Movie Ratings': tf.train.Feature(
 29 |                         float_list=tf.train.FloatList(
 30 |                             value=example['Movie Ratings'])),
 31 |                     'Suggestion': tf.train.Feature(
 32 |                         bytes_list=tf.train.BytesList(
 33 |                             value=[example['Suggestion'].encode('utf-8')])),
 34 |                     'Suggestion Purchased': tf.train.Feature(
 35 |                         float_list=tf.train.FloatList(
 36 |                             value=[example['Suggestion Purchased']])),
 37 |                     'Purchase Price': tf.train.Feature(
 38 |                         float_list=tf.train.FloatList(
 39 |                             value=[example['Purchase Price']]))
 40 |                 }
 41 |             )
 42 |         )
 43 |     )
 44 | 
 45 |   return examples
 46 | 
 47 | 
 48 | def generate_serialized_movie_examples():
 49 |   """Return a serialized version of the above data for byte insertion."""
 50 | 
 51 |   return [example.SerializeToString() for example in generate_movie_examples()]
 52 | 
 53 | 
 54 | def write_example_to_tfrecord(example, file_path):
 55 |   """Write example(s) to a single TFRecord file."""
 56 | 
 57 |   with tf.io.TFRecordWriter(file_path) as writer:
 58 |     writer.write(example.SerializeToString())
 59 | 
 60 | 
 61 | # Write example(s) to a single ArrayRecord file.
 62 | def write_example_to_arrayrecord(example, file_path):
 63 |   writer = ArrayRecordWriter(file_path, 'group_size:1')
 64 |   writer.write(example.SerializeToString())
 65 |   writer.close()
 66 | 
 67 | 
 68 | def kitty_tfrecord(prefix=''):
 69 |   """Create a TFRecord from a cat pic on the Internet.
 70 | 
 71 |   This is mainly for testing; probably don't use it.
 72 | 
 73 |   Args:
 74 |     prefix: A file directory in string format.
 75 |   """
 76 | 
 77 |   cat_in_snow = tf.keras.utils.get_file(
 78 |       '320px-Felis_catus-cat_on_snow.jpg',
 79 |       'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
 80 | 
 81 |   image_labels = {
 82 |       cat_in_snow: 0
 83 |   }
 84 | 
 85 |   image_string = open(cat_in_snow, 'rb').read()
 86 |   label = image_labels[cat_in_snow]
 87 |   image_shape = tf.io.decode_jpeg(image_string).shape
 88 | 
 89 |   feature = {
 90 |       'height': tf.train.Feature(int64_list=tf.train.Int64List(
 91 |           value=[image_shape[0]])),
 92 |       'width': tf.train.Feature(int64_list=tf.train.Int64List(
 93 |           value=[image_shape[1]])),
 94 |       'depth': tf.train.Feature(int64_list=tf.train.Int64List(
 95 |           value=[image_shape[2]])),
 96 |       'label': tf.train.Feature(int64_list=tf.train.Int64List(
 97 |           value=[label])),
 98 |       'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(
 99 |           value=[image_string]))
100 |   }
101 | 
102 |   example = tf.train.Example(features=tf.train.Features(feature=feature))
103 | 
104 |   record_file = os.path.join(prefix, 'kittymeow.tfrecord')
105 |   with tf.io.TFRecordWriter(record_file) as writer:
106 |     writer.write(example.SerializeToString())
--------------------------------------------------------------------------------
/beam/examples/example_full_demo_cli.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Execute this via bash to run a full demo that creates TFRecords and converts them.
4 | 
5 | 
6 | # Set bucket info below. Uncomment the lower lines and set values to use Dataflow.
 7 | python -m array_record.beam.demo \
 8 |   --input="gs://<YOUR_BUCKET>/records/movies" \
 9 |   --output="gs://<YOUR_BUCKET>/records/" \
10 |   # --region="<YOUR_REGION>" \
11 |   # --runner="DataflowRunner" \
12 |   # --project="<YOUR_PROJECT>" \
13 |   # --requirements_file="requirements.txt"
--------------------------------------------------------------------------------
/beam/examples/example_gcs_conversion.py:
--------------------------------------------------------------------------------
 1 | """Execute this to convert an existing set of TFRecords to ArrayRecords."""
 2 | 
 3 | 
 4 | from apache_beam.options import pipeline_options
 5 | from array_record.beam.pipelines import convert_tf_to_arrayrecord_gcs
 6 | 
 7 | ## Set input and output patterns as specified
 8 | input_pattern = 'gs://<YOUR_BUCKET>/records/*.tfrecord'
 9 | output_path = 'gs://<YOUR_BUCKET>/records/'
10 | 
11 | args = {'input': input_pattern, 'output': output_path}
12 | 
13 | ## If running on Dataflow, set pipeline options and uncomment in main()
14 | ## If pipeline_options is not set, a local runner is used
15 | pipeline_options = pipeline_options.PipelineOptions(
16 |     runner='DataflowRunner',
17 |     project='<YOUR_PROJECT>',
18 |     region='<YOUR_REGION>',
19 |     requirements_file='requirements.txt'
20 | )
21 | 
22 | 
23 | def main():
24 |   convert_tf_to_arrayrecord_gcs(
25 |       args=args,
26 |       # pipeline_options=pipeline_options,
27 |   ).run()
28 | 
29 | if __name__ == '__main__':
30 |   main()
--------------------------------------------------------------------------------
/beam/examples/example_sink_conversion.py:
--------------------------------------------------------------------------------
 1 | """Execute this to convert TFRecords to ArrayRecords using the disk Sink."""
 2 | 
 3 | 
 4 | from apache_beam.options import pipeline_options
 5 | from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 6 | 
 7 | ## Set input and output patterns as specified
 8 | input_pattern = 'gs://<YOUR_BUCKET>/records/*.tfrecord'
 9 | output_path = 'records/movies'
10 | 
11 | args = {'input': input_pattern, 'output': output_path}
12 | 
13 | ## If running on Dataflow, set pipeline options and uncomment in main()
14 | ## If pipeline_options is not set, a local runner is used
15 | pipeline_options = pipeline_options.PipelineOptions(
16 |     runner='DataflowRunner',
17 |     project='<YOUR_PROJECT>',
18 |     region='<YOUR_REGION>',
19 |     requirements_file='requirements.txt'
20 | )
21 | 
22 | 
23 | def main():
24 |   convert_tf_to_arrayrecord_disk_match_shards(
25 |       args=args,
26 |       # pipeline_options=pipeline_options,
27 |   ).run()
28 | 
29 | if __name__ == '__main__':
30 |   main()
--------------------------------------------------------------------------------
/beam/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | array-record[beam]
2 | google-cloud-storage==2.11.0
3 | tensorflow==2.14.0
--------------------------------------------------------------------------------
/beam/options.py:
--------------------------------------------------------------------------------
 1 | """Handler for Pipeline and Beam options that allows for cleaner importing."""
 2 | 
 3 | 
 4 | import argparse
 5 | from apache_beam.options import pipeline_options
 6 | 
 7 | 
 8 | def get_arguments():
 9 |   """Simple external wrapper for argparse that allows for manual construction.
10 | 
11 |   Returns:
12 |     1. A dictionary of known args for use in pipelines
13 |     2.
The remainder of the arguments in PipelineOptions format 14 | 15 | """ 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | '--input', 20 | help='The file pattern for the input TFRecords.',) 21 | parser.add_argument( 22 | '--output', 23 | help='The path prefix for output ArrayRecords.') 24 | 25 | args, beam_args = parser.parse_known_args() 26 | return(args.__dict__, pipeline_options.PipelineOptions(beam_args)) 27 | -------------------------------------------------------------------------------- /beam/pipelines.py: -------------------------------------------------------------------------------- 1 | """Various opinionated Beam pipelines for testing different functionality.""" 2 | 3 | import apache_beam as beam 4 | from apache_beam.coders import coders 5 | from . import arrayrecordio 6 | from . import dofns 7 | from . import example 8 | from . import options 9 | 10 | 11 | ## Grab CLI arguments. 12 | ## Override by passing args/pipeline_options to the function manually. 13 | def_args, def_pipeline_options = options.get_arguments() 14 | 15 | 16 | def example_to_tfrecord( 17 | num_shards=1, 18 | args=def_args, 19 | pipeline_options=def_pipeline_options): 20 | """Beam pipeline for creating example TFRecord data. 21 | 22 | Args: 23 | num_shards: Number of files 24 | args: Custom arguments 25 | pipeline_options: Beam arguments in dict format 26 | 27 | Returns: 28 | Beam Pipeline object 29 | """ 30 | 31 | p1 = beam.Pipeline(options=pipeline_options) 32 | _ = ( 33 | p1 34 | | 'Create' >> beam.Create(example.generate_movie_examples()) 35 | | 'Write' 36 | >> beam.io.WriteToTFRecord( 37 | args['output'], 38 | coder=coders.ToBytesCoder(), 39 | num_shards=num_shards, 40 | file_name_suffix='.tfrecord', 41 | ) 42 | ) 43 | return p1 44 | 45 | 46 | def example_to_arrayrecord( 47 | num_shards=1, args=def_args, pipeline_options=def_pipeline_options 48 | ): 49 | """Beam pipeline for creating example ArrayRecord data. 50 | 51 | Args: 52 | num_shards: Number of files 53 | args: Custom arguments 54 | pipeline_options: Beam arguments in dict format 55 | 56 | Returns: 57 | Beam Pipeline object 58 | """ 59 | 60 | p1 = beam.Pipeline(options=pipeline_options) 61 | _ = ( 62 | p1 63 | | 'Create' >> beam.Create(example.generate_movie_examples()) 64 | | 'Write' 65 | >> arrayrecordio.WriteToArrayRecord( 66 | args['output'], 67 | coder=coders.ToBytesCoder(), 68 | num_shards=num_shards, 69 | file_name_suffix='.arrayrecord', 70 | ) 71 | ) 72 | return p1 73 | 74 | 75 | def convert_tf_to_arrayrecord_disk( 76 | num_shards=1, args=def_args, pipeline_options=def_pipeline_options 77 | ): 78 | """Convert TFRecords to ArrayRecords using sink/sharding functionality. 79 | 80 | THIS ONLY WORKS FOR DISK ARRAYRECORD WRITES 81 | 82 | Args: 83 | num_shards: Number of files 84 | args: Custom arguments 85 | pipeline_options: Beam arguments in dict format 86 | 87 | Returns: 88 | Beam Pipeline object 89 | """ 90 | 91 | p1 = beam.Pipeline(options=pipeline_options) 92 | _ = ( 93 | p1 94 | | 'Read TFRecord' >> beam.io.ReadFromTFRecord(args['input']) 95 | | 'Write ArrayRecord' 96 | >> arrayrecordio.WriteToArrayRecord( 97 | args['output'], 98 | coder=coders.ToBytesCoder(), 99 | num_shards=num_shards, 100 | file_name_suffix='.arrayrecord', 101 | ) 102 | ) 103 | return p1 104 | 105 | 106 | def convert_tf_to_arrayrecord_disk_match_shards( 107 | args=def_args, pipeline_options=def_pipeline_options 108 | ): 109 | """Convert TFRecords to matching number of ArrayRecords. 
110 | 111 | THIS ONLY WORKS FOR DISK ARRAYRECORD WRITES 112 | 113 | Args: 114 | args: Custom arguments 115 | pipeline_options: Beam arguments in dict format 116 | 117 | Returns: 118 | Beam Pipeline object 119 | """ 120 | 121 | p1 = beam.Pipeline(options=pipeline_options) 122 | initial = ( 123 | p1 124 | | 'Start' >> beam.Create([args['input']]) 125 | | 'Read' >> beam.io.ReadAllFromTFRecord(with_filename=True) 126 | ) 127 | 128 | file_count = ( 129 | initial 130 | | 'Group' >> beam.GroupByKey() 131 | | 'Count Shards' >> beam.combiners.Count.Globally() 132 | ) 133 | 134 | _ = ( 135 | initial 136 | | 'Drop Filename' >> beam.Map(lambda x: x[1]) 137 | | 'Write ArrayRecord' 138 | >> arrayrecordio.WriteToArrayRecord( 139 | args['output'], 140 | coder=coders.ToBytesCoder(), 141 | num_shards=beam.pvalue.AsSingleton(file_count), 142 | file_name_suffix='.arrayrecord', 143 | ) 144 | ) 145 | return p1 146 | 147 | 148 | def convert_tf_to_arrayrecord_gcs( 149 | overwrite_extension=False, 150 | file_path_suffix='.arrayrecord', 151 | args=def_args, 152 | pipeline_options=def_pipeline_options): 153 | """Convert TFRecords to ArrayRecords in GCS 1:1. 154 | 155 | Args: 156 | overwrite_extension: Boolean making DoFn attempt to overwrite extension 157 | file_path_suffix: Intended suffix for overwrite or append 158 | args: Custom arguments 159 | pipeline_options: Beam arguments in dict format 160 | 161 | Returns: 162 | Beam Pipeline object 163 | """ 164 | 165 | p1 = beam.Pipeline(options=pipeline_options) 166 | _ = ( 167 | p1 168 | | 'Start' >> beam.Create([args['input']]) 169 | | 'Read' >> beam.io.ReadAllFromTFRecord(with_filename=True) 170 | | 'Group' >> beam.GroupByKey() 171 | | 'Write to ArrayRecord in GCS' 172 | >> beam.ParDo( 173 | dofns.ConvertToArrayRecordGCS(), 174 | args['output'], 175 | file_path_suffix=file_path_suffix, 176 | overwrite_extension=overwrite_extension, 177 | ) 178 | ) 179 | return p1 180 | -------------------------------------------------------------------------------- /beam/testdata.py: -------------------------------------------------------------------------------- 1 | """Simple test data wrapper. 2 | 3 | Separated to keep Tensorflow and Beam dependencies away from test data 4 | """ 5 | 6 | # Hardcoded multirecord dataset in dict format for testing and demo. 7 | data = [ 8 | { 9 | 'Age': 29, 10 | 'Movie': ['The Shawshank Redemption', 'Fight Club'], 11 | 'Movie Ratings': [9.0, 9.7], 12 | 'Suggestion': 'Inception', 13 | 'Suggestion Purchased': 1.0, 14 | 'Purchase Price': 9.99 15 | }, 16 | { 17 | 'Age': 39, 18 | 'Movie': ['The Prestige', 'The Big Lebowski', 'The Fall'], 19 | 'Movie Ratings': [9.5, 8.5, 8.5], 20 | 'Suggestion': 'Interstellar', 21 | 'Suggestion Purchased': 1.0, 22 | 'Purchase Price': 14.99 23 | }, 24 | { 25 | 'Age': 19, 26 | 'Movie': ['Barbie', 'The Batman', 'Boss Baby', 'Oppenheimer'], 27 | 'Movie Ratings': [9.6, 8.2, 10.0, 4.2], 28 | 'Suggestion': 'Secret Life of Pets', 29 | 'Suggestion Purchased': 0.0, 30 | 'Purchase Price': 25.99 31 | }, 32 | { 33 | 'Age': 35, 34 | 'Movie': ['The Mothman Prophecies', 'Sinister'], 35 | 'Movie Ratings': [8.3, 9.0], 36 | 'Suggestion': 'Hereditary', 37 | 'Suggestion Purchased': 1.0, 38 | 'Purchase Price': 12.99 39 | } 40 | ] 41 | -------------------------------------------------------------------------------- /cpp/BUILD: -------------------------------------------------------------------------------- 1 | # ArrayRecord is a new file format for IO intensive applications. 
2 | # It supports efficient random access and various compression algorithms. 3 | 4 | load("@rules_proto//proto:defs.bzl", "proto_library") 5 | 6 | package(default_visibility = ["//visibility:public"]) 7 | 8 | licenses(["notice"]) 9 | 10 | proto_library( 11 | name = "layout_proto", 12 | srcs = ["layout.proto"], 13 | ) 14 | 15 | cc_proto_library( 16 | name = "layout_cc_proto", 17 | deps = [":layout_proto"], 18 | ) 19 | 20 | cc_library( 21 | name = "common", 22 | hdrs = ["common.h"], 23 | deps = [ 24 | "@abseil-cpp//absl/base:core_headers", 25 | "@abseil-cpp//absl/status", 26 | "@abseil-cpp//absl/strings:str_format", 27 | ], 28 | ) 29 | 30 | cc_library( 31 | name = "sequenced_chunk_writer", 32 | srcs = ["sequenced_chunk_writer.cc"], 33 | hdrs = ["sequenced_chunk_writer.h"], 34 | deps = [ 35 | ":common", 36 | "@abseil-cpp//absl/base:core_headers", 37 | "@abseil-cpp//absl/status", 38 | "@abseil-cpp//absl/status:statusor", 39 | "@abseil-cpp//absl/strings:str_format", 40 | "@abseil-cpp//absl/synchronization", 41 | "@riegeli//riegeli/base:initializer", 42 | "@riegeli//riegeli/base:object", 43 | "@riegeli//riegeli/base:status", 44 | "@riegeli//riegeli/base:types", 45 | "@riegeli//riegeli/bytes:writer", 46 | "@riegeli//riegeli/chunk_encoding:chunk", 47 | "@riegeli//riegeli/chunk_encoding:constants", 48 | "@riegeli//riegeli/records:chunk_writer", 49 | ], 50 | ) 51 | 52 | cc_library( 53 | name = "thread_pool", 54 | srcs = ["thread_pool.cc"], 55 | hdrs = ["thread_pool.h"], 56 | deps = [ 57 | "@abseil-cpp//absl/flags:flag", 58 | "@eigen//:eigen", 59 | ], 60 | ) 61 | 62 | cc_library( 63 | name = "parallel_for", 64 | hdrs = ["parallel_for.h"], 65 | deps = [ 66 | ":common", 67 | ":thread_pool", 68 | "@abseil-cpp//absl/base:core_headers", 69 | "@abseil-cpp//absl/status", 70 | "@abseil-cpp//absl/synchronization", 71 | ], 72 | ) 73 | 74 | cc_library( 75 | name = "tri_state_ptr", 76 | hdrs = ["tri_state_ptr.h"], 77 | deps = [ 78 | ":common", 79 | "@abseil-cpp//absl/base:core_headers", 80 | "@abseil-cpp//absl/synchronization", 81 | ], 82 | ) 83 | 84 | cc_library( 85 | name = "test_utils", 86 | testonly = True, 87 | srcs = ["test_utils.cc"], 88 | hdrs = ["test_utils.h"], 89 | deps = [":common"], 90 | ) 91 | 92 | cc_test( 93 | name = "test_utils_test", 94 | srcs = ["test_utils_test.cc"], 95 | deps = [ 96 | ":common", 97 | ":test_utils", 98 | "@abseil-cpp//absl/strings", 99 | "@googletest//:gtest_main", 100 | ], 101 | ) 102 | 103 | cc_library( 104 | name = "array_record_writer", 105 | srcs = ["array_record_writer.cc"], 106 | hdrs = ["array_record_writer.h"], 107 | deps = [ 108 | ":common", 109 | ":layout_cc_proto", 110 | ":sequenced_chunk_writer", 111 | ":thread_pool", 112 | ":tri_state_ptr", 113 | "@abseil-cpp//absl/base:core_headers", 114 | "@abseil-cpp//absl/log:check", 115 | "@abseil-cpp//absl/status", 116 | "@abseil-cpp//absl/status:statusor", 117 | "@abseil-cpp//absl/strings", 118 | "@abseil-cpp//absl/strings:cord", 119 | "@abseil-cpp//absl/synchronization", 120 | "@abseil-cpp//absl/types:span", 121 | "@protobuf//:protobuf_lite", 122 | "@riegeli//riegeli/base:initializer", 123 | "@riegeli//riegeli/base:object", 124 | "@riegeli//riegeli/base:options_parser", 125 | "@riegeli//riegeli/base:status", 126 | "@riegeli//riegeli/bytes:chain_writer", 127 | "@riegeli//riegeli/bytes:writer", 128 | "@riegeli//riegeli/chunk_encoding:chunk", 129 | "@riegeli//riegeli/chunk_encoding:chunk_encoder", 130 | "@riegeli//riegeli/chunk_encoding:compressor_options", 131 | "@riegeli//riegeli/chunk_encoding:constants", 132 | 
"@riegeli//riegeli/chunk_encoding:deferred_encoder", 133 | "@riegeli//riegeli/chunk_encoding:simple_encoder", 134 | "@riegeli//riegeli/chunk_encoding:transpose_encoder", 135 | "@riegeli//riegeli/records:records_metadata_cc_proto", 136 | ], 137 | ) 138 | 139 | cc_library( 140 | name = "masked_reader", 141 | srcs = ["masked_reader.cc"], 142 | hdrs = ["masked_reader.h"], 143 | deps = [ 144 | ":common", 145 | "@abseil-cpp//absl/memory", 146 | "@abseil-cpp//absl/status", 147 | "@abseil-cpp//absl/time", 148 | "@abseil-cpp//absl/types:optional", 149 | "@riegeli//riegeli/base:object", 150 | "@riegeli//riegeli/base:status", 151 | "@riegeli//riegeli/base:types", 152 | "@riegeli//riegeli/bytes:reader", 153 | ], 154 | ) 155 | 156 | cc_library( 157 | name = "array_record_reader", 158 | srcs = ["array_record_reader.cc"], 159 | hdrs = ["array_record_reader.h"], 160 | deps = [ 161 | ":common", 162 | ":layout_cc_proto", 163 | ":masked_reader", 164 | ":parallel_for", 165 | ":thread_pool", 166 | ":tri_state_ptr", 167 | "@abseil-cpp//absl/base:core_headers", 168 | "@abseil-cpp//absl/functional:any_invocable", 169 | "@abseil-cpp//absl/functional:function_ref", 170 | "@abseil-cpp//absl/status", 171 | "@abseil-cpp//absl/status:statusor", 172 | "@abseil-cpp//absl/strings", 173 | "@abseil-cpp//absl/strings:str_format", 174 | "@abseil-cpp//absl/types:span", 175 | "@protobuf//:protobuf_lite", 176 | "@riegeli//riegeli/base:initializer", 177 | "@riegeli//riegeli/base:object", 178 | "@riegeli//riegeli/base:options_parser", 179 | "@riegeli//riegeli/base:status", 180 | "@riegeli//riegeli/bytes:reader", 181 | "@riegeli//riegeli/chunk_encoding:chunk", 182 | "@riegeli//riegeli/chunk_encoding:chunk_decoder", 183 | "@riegeli//riegeli/records:chunk_reader", 184 | ], 185 | ) 186 | 187 | cc_test( 188 | name = "sequenced_chunk_writer_test", 189 | srcs = ["sequenced_chunk_writer_test.cc"], 190 | deps = [ 191 | ":common", 192 | ":sequenced_chunk_writer", 193 | ":thread_pool", 194 | "@abseil-cpp//absl/status", 195 | "@abseil-cpp//absl/status:statusor", 196 | "@abseil-cpp//absl/strings:cord", 197 | "@abseil-cpp//absl/strings:string_view", 198 | "@abseil-cpp//absl/types:span", 199 | "@googletest//:gtest_main", 200 | "@riegeli//riegeli/base:initializer", 201 | "@riegeli//riegeli/base:shared_ptr", 202 | "@riegeli//riegeli/bytes:chain_writer", 203 | "@riegeli//riegeli/bytes:cord_writer", 204 | "@riegeli//riegeli/bytes:string_reader", 205 | "@riegeli//riegeli/bytes:string_writer", 206 | "@riegeli//riegeli/chunk_encoding:chunk", 207 | "@riegeli//riegeli/chunk_encoding:compressor_options", 208 | "@riegeli//riegeli/chunk_encoding:constants", 209 | "@riegeli//riegeli/chunk_encoding:simple_encoder", 210 | "@riegeli//riegeli/records:record_reader", 211 | ], 212 | ) 213 | 214 | cc_test( 215 | name = "tri_state_ptr_test", 216 | srcs = ["tri_state_ptr_test.cc"], 217 | deps = [ 218 | ":common", 219 | ":thread_pool", 220 | ":tri_state_ptr", 221 | "@abseil-cpp//absl/synchronization", 222 | "@googletest//:gtest_main", 223 | "@riegeli//riegeli/base:initializer", 224 | ], 225 | ) 226 | 227 | cc_test( 228 | name = "array_record_writer_test", 229 | srcs = ["array_record_writer_test.cc"], 230 | shard_count = 4, 231 | tags = ["notsan"], 232 | deps = [ 233 | ":array_record_writer", 234 | ":common", 235 | ":layout_cc_proto", 236 | ":test_utils", 237 | ":thread_pool", 238 | "@abseil-cpp//absl/strings", 239 | "@abseil-cpp//absl/strings:cord", 240 | "@abseil-cpp//absl/strings:cord_test_helpers", 241 | "@googletest//:gtest_main", 242 | 
"@riegeli//riegeli/base:initializer", 243 | "@riegeli//riegeli/bytes:string_reader", 244 | "@riegeli//riegeli/bytes:string_writer", 245 | "@riegeli//riegeli/chunk_encoding:constants", 246 | "@riegeli//riegeli/records:record_reader", 247 | "@riegeli//riegeli/records:records_metadata_cc_proto", 248 | ], 249 | ) 250 | 251 | cc_test( 252 | name = "masked_reader_test", 253 | srcs = ["masked_reader_test.cc"], 254 | deps = [ 255 | ":masked_reader", 256 | "@googletest//:gtest_main", 257 | "@riegeli//riegeli/bytes:string_reader", 258 | ], 259 | ) 260 | 261 | cc_test( 262 | name = "parallel_for_test", 263 | size = "small", 264 | srcs = ["parallel_for_test.cc"], 265 | deps = [ 266 | ":common", 267 | ":parallel_for", 268 | ":thread_pool", 269 | "@abseil-cpp//absl/functional:function_ref", 270 | "@abseil-cpp//absl/status", 271 | "@googletest//:gtest_main", 272 | ], 273 | ) 274 | 275 | cc_test( 276 | name = "array_record_reader_test", 277 | srcs = ["array_record_reader_test.cc"], 278 | shard_count = 4, 279 | deps = [ 280 | ":array_record_reader", 281 | ":array_record_writer", 282 | ":common", 283 | ":layout_cc_proto", 284 | ":test_utils", 285 | ":thread_pool", 286 | "@abseil-cpp//absl/functional:function_ref", 287 | "@abseil-cpp//absl/status", 288 | "@abseil-cpp//absl/strings", 289 | "@googletest//:gtest_main", 290 | "@riegeli//riegeli/base:initializer", 291 | "@riegeli//riegeli/bytes:string_reader", 292 | "@riegeli//riegeli/bytes:string_writer", 293 | ], 294 | ) 295 | -------------------------------------------------------------------------------- /cpp/array_record_reader_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2022 Google LLC. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "cpp/array_record_reader.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "gtest/gtest.h" 29 | #include "absl/functional/function_ref.h" 30 | #include "absl/status/status.h" 31 | #include "absl/strings/string_view.h" 32 | #include "cpp/array_record_writer.h" 33 | #include "cpp/common.h" 34 | #include "cpp/layout.pb.h" 35 | #include "cpp/test_utils.h" 36 | #include "cpp/thread_pool.h" 37 | #include "riegeli/base/maker.h" 38 | #include "riegeli/bytes/string_reader.h" 39 | #include "riegeli/bytes/string_writer.h" 40 | 41 | constexpr uint32_t kDatasetSize = 10050; 42 | 43 | namespace array_record { 44 | namespace { 45 | 46 | enum class CompressionType { kUncompressed, kBrotli, kZstd, kSnappy }; 47 | 48 | // Tuple params 49 | // CompressionType 50 | // transpose 51 | // use_thread_pool 52 | // optimize_for_random_access 53 | class ArrayRecordReaderTest 54 | : public testing::TestWithParam< 55 | std::tuple> { 56 | public: 57 | ARThreadPool* get_pool() { return ArrayRecordGlobalPool(); } 58 | ArrayRecordWriterBase::Options GetWriterOptions() { 59 | auto options = ArrayRecordWriterBase::Options(); 60 | switch (std::get<0>(GetParam())) { 61 | case CompressionType::kUncompressed: 62 | options.set_uncompressed(); 63 | break; 64 | case CompressionType::kBrotli: 65 | options.set_brotli(); 66 | break; 67 | case CompressionType::kZstd: 68 | options.set_zstd(); 69 | break; 70 | case CompressionType::kSnappy: 71 | options.set_snappy(); 72 | break; 73 | } 74 | options.set_transpose(transpose()); 75 | return options; 76 | } 77 | 78 | bool transpose() { return std::get<1>(GetParam()); } 79 | bool use_thread_pool() { return std::get<2>(GetParam()); } 80 | bool optimize_for_random_access() { return std::get<3>(GetParam()); } 81 | }; 82 | 83 | TEST_P(ArrayRecordReaderTest, MoveTest) { 84 | std::string encoded; 85 | auto writer_options = GetWriterOptions().set_group_size(2); 86 | auto writer = ArrayRecordWriter( 87 | riegeli::Maker(&encoded), writer_options, nullptr); 88 | 89 | // Empty string should not crash the writer or the reader. 90 | std::vector test_str{"aaa", "", "ccc", "dd", "e"}; 91 | for (const auto& s : test_str) { 92 | EXPECT_TRUE(writer.WriteRecord(s)); 93 | } 94 | ASSERT_TRUE(writer.Close()); 95 | 96 | auto reader_opt = ArrayRecordReaderBase::Options(); 97 | if (optimize_for_random_access()) { 98 | reader_opt.set_max_parallelism(0); 99 | reader_opt.set_readahead_buffer_size(0); 100 | } 101 | 102 | auto reader_before_move = 103 | ArrayRecordReader(riegeli::Maker(encoded), 104 | reader_opt, use_thread_pool() ? get_pool() : nullptr); 105 | ASSERT_TRUE(reader_before_move.status().ok()); 106 | 107 | ASSERT_TRUE( 108 | reader_before_move 109 | .ParallelReadRecords([&](uint64_t record_index, 110 | absl::string_view record) -> absl::Status { 111 | EXPECT_EQ(record, test_str[record_index]); 112 | return absl::OkStatus(); 113 | }) 114 | .ok()); 115 | 116 | EXPECT_EQ(reader_before_move.RecordGroupSize(), 2); 117 | 118 | ArrayRecordReader reader = std::move(reader_before_move); 119 | // Once a reader is moved, it is closed. 
120 | ASSERT_FALSE(reader_before_move.is_open()); // NOLINT 121 | 122 | auto recorded_writer_options = ArrayRecordWriterBase::Options::FromString( 123 | reader.WriterOptionsString().value()) 124 | .value(); 125 | EXPECT_EQ(writer_options.compression_type(), 126 | recorded_writer_options.compression_type()); 127 | EXPECT_EQ(writer_options.compression_level(), 128 | recorded_writer_options.compression_level()); 129 | EXPECT_EQ(writer_options.transpose(), recorded_writer_options.transpose()); 130 | 131 | std::vector indices = {1, 2, 4}; 132 | ASSERT_TRUE(reader 133 | .ParallelReadRecordsWithIndices( 134 | indices, 135 | [&](uint64_t indices_idx, 136 | absl::string_view record) -> absl::Status { 137 | EXPECT_EQ(record, test_str[indices[indices_idx]]); 138 | return absl::OkStatus(); 139 | }) 140 | .ok()); 141 | 142 | absl::string_view record_view; 143 | for (auto i : IndicesOf(test_str)) { 144 | EXPECT_TRUE(reader.ReadRecord(&record_view)); 145 | EXPECT_EQ(record_view, test_str[i]); 146 | } 147 | // Cannot read once we are at the end of the file. 148 | EXPECT_FALSE(reader.ReadRecord(&record_view)); 149 | // But the reader should still be healthy. 150 | EXPECT_TRUE(reader.ok()); 151 | 152 | // Seek to a particular record works. 153 | EXPECT_TRUE(reader.SeekRecord(2)); 154 | EXPECT_TRUE(reader.ReadRecord(&record_view)); 155 | EXPECT_EQ(record_view, test_str[2]); 156 | 157 | // Seek out of bound would not fail. 158 | EXPECT_TRUE(reader.SeekRecord(10)); 159 | EXPECT_FALSE(reader.ReadRecord(&record_view)); 160 | EXPECT_TRUE(reader.ok()); 161 | 162 | EXPECT_EQ(reader.RecordGroupSize(), 2); 163 | 164 | ASSERT_TRUE(reader.Close()); 165 | } 166 | 167 | TEST_P(ArrayRecordReaderTest, RandomDatasetTest) { 168 | std::mt19937 bitgen; 169 | std::vector records(kDatasetSize); 170 | std::uniform_int_distribution<> dist(0, 123); 171 | for (auto i : Seq(kDatasetSize)) { 172 | size_t len = dist(bitgen); 173 | records[i] = MTRandomBytes(bitgen, len); 174 | } 175 | 176 | std::string encoded; 177 | auto writer = 178 | ArrayRecordWriter(riegeli::Maker(&encoded), 179 | GetWriterOptions(), get_pool()); 180 | for (auto i : Seq(kDatasetSize)) { 181 | EXPECT_TRUE(writer.WriteRecord(records[i])); 182 | } 183 | ASSERT_TRUE(writer.Close()); 184 | 185 | auto reader_opt = ArrayRecordReaderBase::Options(); 186 | if (optimize_for_random_access()) { 187 | reader_opt.set_max_parallelism(0); 188 | reader_opt.set_readahead_buffer_size(0); 189 | } 190 | 191 | auto reader = 192 | ArrayRecordReader(riegeli::Maker(encoded), 193 | reader_opt, use_thread_pool() ? 
get_pool() : nullptr); 194 | ASSERT_TRUE(reader.status().ok()); 195 | EXPECT_EQ(reader.NumRecords(), kDatasetSize); 196 | uint64_t group_size = 197 | std::min(ArrayRecordWriterBase::Options::kDefaultGroupSize, kDatasetSize); 198 | EXPECT_EQ(reader.RecordGroupSize(), group_size); 199 | 200 | std::vector read_all_records(kDatasetSize, false); 201 | ASSERT_TRUE(reader 202 | .ParallelReadRecords( 203 | [&](uint64_t record_index, 204 | absl::string_view result_view) -> absl::Status { 205 | EXPECT_EQ(result_view, records[record_index]); 206 | EXPECT_FALSE(read_all_records[record_index]); 207 | read_all_records[record_index] = true; 208 | return absl::OkStatus(); 209 | }) 210 | .ok()); 211 | for (bool record_was_read : read_all_records) { 212 | EXPECT_TRUE(record_was_read); 213 | } 214 | 215 | std::vector indices = {0, 3, 5, 7, 101, 2000}; 216 | std::vector read_indexed_records(indices.size(), false); 217 | ASSERT_TRUE(reader 218 | .ParallelReadRecordsWithIndices( 219 | indices, 220 | [&](uint64_t indices_idx, 221 | absl::string_view result_view) -> absl::Status { 222 | EXPECT_EQ(result_view, records[indices[indices_idx]]); 223 | EXPECT_FALSE(read_indexed_records[indices_idx]); 224 | read_indexed_records[indices_idx] = true; 225 | return absl::OkStatus(); 226 | }) 227 | .ok()); 228 | for (bool record_was_read : read_indexed_records) { 229 | EXPECT_TRUE(record_was_read); 230 | } 231 | 232 | uint64_t begin = 10, end = 101; 233 | std::vector read_range_records(end - begin, false); 234 | ASSERT_TRUE(reader 235 | .ParallelReadRecordsInRange( 236 | begin, end, 237 | [&](uint64_t record_index, 238 | absl::string_view result_view) -> absl::Status { 239 | EXPECT_EQ(result_view, records[record_index]); 240 | EXPECT_FALSE(read_range_records[record_index - begin]); 241 | read_range_records[record_index - begin] = true; 242 | return absl::OkStatus(); 243 | }) 244 | .ok()); 245 | for (bool record_was_read : read_range_records) { 246 | EXPECT_TRUE(record_was_read); 247 | } 248 | 249 | // Test sequential read 250 | absl::string_view result_view; 251 | for (auto record_index : Seq(kDatasetSize)) { 252 | ASSERT_TRUE(reader.ReadRecord(&result_view)); 253 | EXPECT_EQ(result_view, records[record_index]); 254 | } 255 | // Reached to the end. 256 | EXPECT_FALSE(reader.ReadRecord(&result_view)); 257 | EXPECT_TRUE(reader.ok()); 258 | // We can still seek back. 
259 | EXPECT_TRUE(reader.SeekRecord(5)); 260 | EXPECT_TRUE(reader.ReadRecord(&result_view)); 261 | EXPECT_EQ(result_view, records[5]); 262 | 263 | ASSERT_TRUE(reader.Close()); 264 | } 265 | 266 | INSTANTIATE_TEST_SUITE_P( 267 | ParamTest, ArrayRecordReaderTest, 268 | testing::Combine(testing::Values(CompressionType::kUncompressed, 269 | CompressionType::kBrotli, 270 | CompressionType::kZstd, 271 | CompressionType::kSnappy), 272 | testing::Bool(), testing::Bool(), testing::Bool())); 273 | 274 | TEST(ArrayRecordReaderOptionTest, ParserTest) { 275 | { 276 | auto option = ArrayRecordReaderBase::Options::FromString("").value(); 277 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 278 | EXPECT_EQ(option.readahead_buffer_size(), 279 | ArrayRecordReaderBase::Options::kDefaultReadaheadBufferSize); 280 | } 281 | { 282 | auto option = 283 | ArrayRecordReaderBase::Options::FromString("max_parallelism:16") 284 | .value(); 285 | EXPECT_EQ(option.max_parallelism(), 16); 286 | EXPECT_EQ(option.readahead_buffer_size(), 287 | ArrayRecordReaderBase::Options::kDefaultReadaheadBufferSize); 288 | } 289 | { 290 | auto option = ArrayRecordReaderBase::Options::FromString( 291 | "max_parallelism:16,readahead_buffer_size:16384") 292 | .value(); 293 | EXPECT_EQ(option.max_parallelism(), 16); 294 | EXPECT_EQ(option.readahead_buffer_size(), 16384); 295 | } 296 | { 297 | auto option = ArrayRecordReaderBase::Options::FromString( 298 | "max_parallelism:0,readahead_buffer_size:0") 299 | .value(); 300 | EXPECT_EQ(option.max_parallelism(), 0); 301 | EXPECT_EQ(option.readahead_buffer_size(), 0); 302 | } 303 | } 304 | 305 | } // namespace 306 | } // namespace array_record 307 | -------------------------------------------------------------------------------- /cpp/array_record_writer_test.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2022 Google LLC. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "cpp/array_record_writer.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "gtest/gtest.h" 28 | #include "absl/strings/cord.h" 29 | #include "absl/strings/cord_test_helpers.h" 30 | #include "absl/strings/string_view.h" 31 | #include "cpp/common.h" 32 | #include "cpp/layout.pb.h" 33 | #include "cpp/test_utils.h" 34 | #include "cpp/thread_pool.h" 35 | #include "riegeli/base/maker.h" 36 | #include "riegeli/bytes/string_reader.h" 37 | #include "riegeli/bytes/string_writer.h" 38 | #include "riegeli/chunk_encoding/constants.h" 39 | #include "riegeli/records/record_reader.h" 40 | #include "riegeli/records/records_metadata.pb.h" 41 | 42 | namespace array_record { 43 | 44 | namespace { 45 | 46 | enum class CompressionType { kUncompressed, kBrotli, kZstd, kSnappy }; 47 | 48 | // Tuple params 49 | // CompressionType 50 | // padding 51 | // transpose 52 | // use ThreadPool 53 | class ArrayRecordWriterTest 54 | : public testing::TestWithParam< 55 | std::tuple> { 56 | public: 57 | ArrayRecordWriterBase::Options GetOptions() { 58 | auto options = ArrayRecordWriterBase::Options(); 59 | switch (std::get<0>(GetParam())) { 60 | case CompressionType::kUncompressed: 61 | options.set_uncompressed(); 62 | break; 63 | case CompressionType::kBrotli: 64 | options.set_brotli(); 65 | break; 66 | case CompressionType::kZstd: 67 | options.set_zstd(); 68 | break; 69 | case CompressionType::kSnappy: 70 | options.set_snappy(); 71 | break; 72 | } 73 | options.set_pad_to_block_boundary(std::get<1>(GetParam())); 74 | options.set_transpose(std::get<2>(GetParam())); 75 | return options; 76 | } 77 | }; 78 | 79 | template 80 | void SilenceMoveAfterUseForTest(T&) {} 81 | 82 | TEST_P(ArrayRecordWriterTest, MoveTest) { 83 | std::string encoded; 84 | ARThreadPool* pool = nullptr; 85 | if (std::get<3>(GetParam())) { 86 | pool = ArrayRecordGlobalPool(); 87 | } 88 | auto options = GetOptions(); 89 | options.set_group_size(2); 90 | auto writer = ArrayRecordWriter( 91 | riegeli::Maker(&encoded), options, pool); 92 | 93 | // Empty string should not crash the writer/reader. 94 | std::vector test_str{"aaa", "", "ccc", "dd", "e"}; 95 | for (auto i : Seq(3)) { 96 | EXPECT_TRUE(writer.WriteRecord(test_str[i])); 97 | } 98 | 99 | auto moved_writer = std::move(writer); 100 | SilenceMoveAfterUseForTest(writer); 101 | // Once moved, writer is closed. 102 | ASSERT_FALSE(writer.is_open()); 103 | ASSERT_TRUE(moved_writer.is_open()); 104 | // Once moved we can no longer write records. 
105 | EXPECT_FALSE(writer.WriteRecord(test_str[3])); 106 | 107 | ASSERT_TRUE(moved_writer.status().ok()); 108 | EXPECT_TRUE(moved_writer.WriteRecord(test_str[3])); 109 | EXPECT_TRUE(moved_writer.WriteRecord(test_str[4])); 110 | ASSERT_TRUE(moved_writer.Close()); 111 | 112 | auto reader = 113 | riegeli::RecordReader(riegeli::Maker(encoded)); 114 | for (const auto& expected : test_str) { 115 | std::string result; 116 | reader.ReadRecord(result); 117 | EXPECT_EQ(result, expected); 118 | } 119 | } 120 | 121 | TEST_P(ArrayRecordWriterTest, CordTest) { 122 | std::string encoded; 123 | ARThreadPool* pool = nullptr; 124 | if (std::get<3>(GetParam())) { 125 | pool = ArrayRecordGlobalPool(); 126 | } 127 | auto options = GetOptions(); 128 | options.set_group_size(2); 129 | auto writer = ArrayRecordWriter( 130 | riegeli::Maker(&encoded), options, pool); 131 | 132 | absl::Cord flat_cord("test"); 133 | // Empty string should not crash the writer. 134 | absl::Cord empty_cord(""); 135 | absl::Cord fragmented_cord = absl::MakeFragmentedCord({"aaa ", "", "c"}); 136 | 137 | EXPECT_TRUE(writer.WriteRecord(flat_cord)); 138 | EXPECT_TRUE(writer.WriteRecord(empty_cord)); 139 | EXPECT_TRUE(writer.WriteRecord(fragmented_cord)); 140 | ASSERT_TRUE(writer.Close()); 141 | 142 | // Empty string should not crash the reader. 143 | std::vector expected_strings{"test", "", "aaa c"}; 144 | 145 | auto reader = 146 | riegeli::RecordReader(riegeli::Maker(encoded)); 147 | for (const auto& expected : expected_strings) { 148 | std::string result; 149 | reader.ReadRecord(result); 150 | EXPECT_EQ(result, expected); 151 | } 152 | } 153 | 154 | TEST_P(ArrayRecordWriterTest, RandomDatasetTest) { 155 | std::mt19937 bitgen; 156 | constexpr uint32_t kGroupSize = 100; 157 | constexpr uint32_t num_records = 1357; 158 | std::vector records(num_records); 159 | std::uniform_int_distribution<> dist(0, 123); 160 | for (auto i : Seq(num_records)) { 161 | size_t len = dist(bitgen); 162 | records[i] = MTRandomBytes(bitgen, len); 163 | } 164 | // results are stored in encoded 165 | std::string encoded; 166 | 167 | ARThreadPool* pool = nullptr; 168 | if (std::get<3>(GetParam())) { 169 | pool = ArrayRecordGlobalPool(); 170 | } 171 | auto options = GetOptions(); 172 | options.set_group_size(kGroupSize); 173 | 174 | auto writer = ArrayRecordWriter( 175 | riegeli::Maker(&encoded), options, pool); 176 | 177 | for (auto i : Seq(num_records)) { 178 | EXPECT_TRUE(writer.WriteRecord(records[i])); 179 | } 180 | ASSERT_TRUE(writer.Close()); 181 | 182 | auto reader = 183 | riegeli::RecordReader(riegeli::Maker(encoded)); 184 | 185 | // Verify metadata 186 | ASSERT_TRUE(reader.CheckFileFormat()); 187 | 188 | // Verify each record 189 | for (auto i : Seq(num_records)) { 190 | absl::string_view result_view; 191 | ASSERT_TRUE(reader.ReadRecord(result_view)); 192 | EXPECT_EQ(result_view, records[i]); 193 | } 194 | 195 | // Verify postcript 196 | ASSERT_TRUE(reader.Seek(reader.Size().value() - (1 << 16))); 197 | RiegeliPostscript postscript; 198 | ASSERT_TRUE(reader.ReadRecord(postscript)); 199 | ASSERT_EQ(postscript.magic(), 0x71930e704fdae05eULL); 200 | 201 | // Verify Footer 202 | ASSERT_TRUE(reader.Seek(postscript.footer_offset())); 203 | RiegeliFooterMetadata footer_metadata; 204 | ASSERT_TRUE(reader.ReadRecord(footer_metadata)); 205 | ASSERT_EQ(footer_metadata.array_record_metadata().version(), 1); 206 | auto num_chunks = footer_metadata.array_record_metadata().num_chunks(); 207 | std::vector footers(num_chunks); 208 | for (auto i : Seq(num_chunks)) { 209 | 
ASSERT_TRUE(reader.ReadRecord(footers[i])); 210 | } 211 | 212 | // Verify we can access the file randomly by chunk_offset recorded in the 213 | // footer 214 | for (auto i = 0UL; i < num_chunks; ++i) { 215 | ASSERT_TRUE(reader.Seek(footers[i].chunk_offset())); 216 | absl::string_view result_view; 217 | ASSERT_TRUE(reader.ReadRecord(result_view)) << reader.status(); 218 | EXPECT_EQ(result_view, records[i * kGroupSize]); 219 | } 220 | ASSERT_TRUE(reader.Close()); 221 | } 222 | 223 | INSTANTIATE_TEST_SUITE_P( 224 | ParamTest, ArrayRecordWriterTest, 225 | testing::Combine(testing::Values(CompressionType::kUncompressed, 226 | CompressionType::kBrotli, 227 | CompressionType::kZstd, 228 | CompressionType::kSnappy), 229 | testing::Bool(), testing::Bool(), testing::Bool())); 230 | 231 | TEST(ArrayRecordWriterOptionsTest, ParsingTest) { 232 | { 233 | auto option = ArrayRecordWriterBase::Options(); 234 | EXPECT_EQ(option.group_size(), 235 | ArrayRecordWriterBase::Options::kDefaultGroupSize); 236 | EXPECT_FALSE(option.transpose()); 237 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 238 | EXPECT_EQ(option.compressor_options().compression_type(), 239 | riegeli::CompressionType::kZstd); 240 | EXPECT_EQ(option.compressor_options().compression_level(), 3); 241 | EXPECT_FALSE(option.pad_to_block_boundary()); 242 | } 243 | { 244 | auto option = ArrayRecordWriterBase::Options::FromString("").value(); 245 | EXPECT_EQ(option.group_size(), 246 | ArrayRecordWriterBase::Options::kDefaultGroupSize); 247 | EXPECT_FALSE(option.transpose()); 248 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 249 | EXPECT_EQ(option.compressor_options().compression_type(), 250 | riegeli::CompressionType::kZstd); 251 | EXPECT_EQ(option.compressor_options().compression_level(), 3); 252 | EXPECT_EQ(option.compressor_options().window_log().value(), 20); 253 | EXPECT_FALSE(option.pad_to_block_boundary()); 254 | 255 | EXPECT_EQ(option.ToString(), 256 | "group_size:65536," 257 | "transpose:false," 258 | "pad_to_block_boundary:false," 259 | "zstd:3," 260 | "window_log:20"); 261 | EXPECT_TRUE( 262 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 263 | } 264 | { 265 | auto option = ArrayRecordWriterBase::Options::FromString("default").value(); 266 | EXPECT_EQ(option.group_size(), 267 | ArrayRecordWriterBase::Options::kDefaultGroupSize); 268 | EXPECT_FALSE(option.transpose()); 269 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 270 | EXPECT_EQ(option.compressor_options().compression_type(), 271 | riegeli::CompressionType::kZstd); 272 | EXPECT_EQ(option.compressor_options().compression_level(), 3); 273 | EXPECT_EQ(option.compressor_options().window_log().value(), 20); 274 | EXPECT_FALSE(option.pad_to_block_boundary()); 275 | 276 | EXPECT_EQ(option.ToString(), 277 | "group_size:65536," 278 | "transpose:false," 279 | "pad_to_block_boundary:false," 280 | "zstd:3," 281 | "window_log:20"); 282 | EXPECT_TRUE( 283 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 284 | } 285 | { 286 | auto option = ArrayRecordWriterBase::Options::FromString( 287 | "group_size:32,transpose,window_log:20") 288 | .value(); 289 | EXPECT_EQ(option.group_size(), 32); 290 | EXPECT_TRUE(option.transpose()); 291 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 292 | EXPECT_EQ(option.compressor_options().compression_type(), 293 | riegeli::CompressionType::kZstd); 294 | EXPECT_EQ(option.compressor_options().window_log(), 20); 295 | EXPECT_FALSE(option.pad_to_block_boundary()); 296 | 297 | EXPECT_EQ(option.ToString(), 
298 | "group_size:32," 299 | "transpose:true," 300 | "pad_to_block_boundary:false," 301 | "transpose_bucket_size:256," 302 | "zstd:3," 303 | "window_log:20"); 304 | EXPECT_TRUE( 305 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 306 | } 307 | { 308 | auto option = ArrayRecordWriterBase::Options::FromString( 309 | "brotli:6,group_size:32,transpose,window_log:25") 310 | .value(); 311 | EXPECT_EQ(option.group_size(), 32); 312 | EXPECT_TRUE(option.transpose()); 313 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 314 | EXPECT_EQ(option.compressor_options().compression_type(), 315 | riegeli::CompressionType::kBrotli); 316 | EXPECT_EQ(option.compressor_options().window_log(), 25); 317 | EXPECT_FALSE(option.pad_to_block_boundary()); 318 | 319 | EXPECT_EQ(option.ToString(), 320 | "group_size:32," 321 | "transpose:true," 322 | "pad_to_block_boundary:false," 323 | "transpose_bucket_size:256," 324 | "brotli:6," 325 | "window_log:25"); 326 | EXPECT_TRUE( 327 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 328 | } 329 | { 330 | auto option = ArrayRecordWriterBase::Options::FromString( 331 | "group_size:32,transpose,zstd:5") 332 | .value(); 333 | EXPECT_EQ(option.group_size(), 32); 334 | EXPECT_TRUE(option.transpose()); 335 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 336 | EXPECT_EQ(option.compressor_options().compression_type(), 337 | riegeli::CompressionType::kZstd); 338 | EXPECT_EQ(option.compressor_options().window_log(), 20); 339 | EXPECT_EQ(option.compressor_options().compression_level(), 5); 340 | EXPECT_FALSE(option.pad_to_block_boundary()); 341 | 342 | EXPECT_EQ(option.ToString(), 343 | "group_size:32," 344 | "transpose:true," 345 | "pad_to_block_boundary:false," 346 | "transpose_bucket_size:256," 347 | "zstd:5," 348 | "window_log:20"); 349 | EXPECT_TRUE( 350 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 351 | } 352 | { 353 | auto option = ArrayRecordWriterBase::Options::FromString( 354 | "uncompressed,pad_to_block_boundary:true") 355 | .value(); 356 | EXPECT_EQ(option.group_size(), 357 | ArrayRecordWriterBase::Options::kDefaultGroupSize); 358 | EXPECT_FALSE(option.transpose()); 359 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 360 | EXPECT_EQ(option.compressor_options().compression_type(), 361 | riegeli::CompressionType::kNone); 362 | EXPECT_TRUE(option.pad_to_block_boundary()); 363 | 364 | EXPECT_EQ(option.ToString(), 365 | "group_size:65536," 366 | "transpose:false," 367 | "pad_to_block_boundary:true," 368 | "uncompressed"); 369 | EXPECT_TRUE( 370 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 371 | } 372 | { 373 | auto option = ArrayRecordWriterBase::Options::FromString( 374 | "snappy,pad_to_block_boundary:true") 375 | .value(); 376 | EXPECT_EQ(option.group_size(), 377 | ArrayRecordWriterBase::Options::kDefaultGroupSize); 378 | EXPECT_FALSE(option.transpose()); 379 | EXPECT_EQ(option.max_parallelism(), std::nullopt); 380 | EXPECT_EQ(option.compressor_options().compression_type(), 381 | riegeli::CompressionType::kSnappy); 382 | EXPECT_TRUE(option.pad_to_block_boundary()); 383 | 384 | EXPECT_EQ(option.ToString(), 385 | "group_size:65536," 386 | "transpose:false," 387 | "pad_to_block_boundary:true," 388 | "snappy"); 389 | EXPECT_TRUE( 390 | ArrayRecordWriterBase::Options::FromString(option.ToString()).ok()); 391 | } 392 | } 393 | 394 | } // namespace 395 | } // namespace array_record 396 | -------------------------------------------------------------------------------- 
/cpp/common.h: --------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_COMMON_H_
17 | #define ARRAY_RECORD_CPP_COMMON_H_
18 |
19 | #include "absl/base/attributes.h"
20 | #include "absl/status/status.h"
21 | #include "absl/strings/str_format.h"
22 |
23 | namespace array_record {
24 |
25 | ////////////////////////////////////////////////////////////////////////////////
26 | // Canonical Errors (with formatting!)
27 | ////////////////////////////////////////////////////////////////////////////////
28 |
29 | template <typename... Args>
30 | ABSL_MUST_USE_RESULT absl::Status FailedPreconditionError(
31 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
32 | return absl::FailedPreconditionError(absl::StrFormat(fmt, args...));
33 | }
34 |
35 | template <typename... Args>
36 | ABSL_MUST_USE_RESULT absl::Status InternalError(
37 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
38 | return absl::InternalError(absl::StrFormat(fmt, args...));
39 | }
40 |
41 | template <typename... Args>
42 | ABSL_MUST_USE_RESULT absl::Status InvalidArgumentError(
43 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
44 | return absl::InvalidArgumentError(absl::StrFormat(fmt, args...));
45 | }
46 |
47 | template <typename... Args>
48 | ABSL_MUST_USE_RESULT absl::Status NotFoundError(
49 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
50 | return absl::NotFoundError(absl::StrFormat(fmt, args...));
51 | }
52 |
53 | template <typename... Args>
54 | ABSL_MUST_USE_RESULT absl::Status OutOfRangeError(
55 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
56 | return absl::OutOfRangeError(absl::StrFormat(fmt, args...));
57 | }
58 |
59 | template <typename... Args>
60 | ABSL_MUST_USE_RESULT absl::Status UnavailableError(
61 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
62 | return absl::UnavailableError(absl::StrFormat(fmt, args...));
63 | }
64 |
65 | template <typename... Args>
66 | ABSL_MUST_USE_RESULT absl::Status UnimplementedError(
67 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
68 | return absl::UnimplementedError(absl::StrFormat(fmt, args...));
69 | }
70 |
71 | template <typename... Args>
72 | ABSL_MUST_USE_RESULT absl::Status UnknownError(
73 | const absl::FormatSpec<Args...>& fmt, const Args&... args) {
74 | return absl::UnknownError(absl::StrFormat(fmt, args...));
75 | }
76 |
77 | // TODO(fchern): Align with what XLA does.
78 | template <typename Int, typename DenomInt>
79 | constexpr Int DivRoundUp(Int num, DenomInt denom) {
80 | // Note: we want DivRoundUp(my_uint64, 17) to just work, so we cast the denom
81 | // to the numerator's type. The result of division always fits in the
82 | // numerator's type, so this is very safe.
83 | return (num + static_cast<Int>(denom) - static_cast<Int>(1)) /
84 | static_cast<Int>(denom);
85 | }
86 |
87 | ////////////////////////////////////////////////////////////////////////////////
88 | // Class Decorators
89 | ////////////////////////////////////////////////////////////////////////////////
90 |
91 | #define DECLARE_COPYABLE_CLASS(ClassName) \
92 | ClassName(ClassName&&) = default; \
93 | ClassName& operator=(ClassName&&) = default; \
94 | ClassName(const ClassName&) = default; \
95 | ClassName& operator=(const ClassName&) = default
96 |
97 | #define DECLARE_MOVE_ONLY_CLASS(ClassName) \
98 | ClassName(ClassName&&) = default; \
99 | ClassName& operator=(ClassName&&) = default; \
100 | ClassName(const ClassName&) = delete; \
101 | ClassName& operator=(const ClassName&) = delete
102 |
103 | #define DECLARE_IMMOBILE_CLASS(ClassName) \
104 | ClassName(ClassName&&) = delete; \
105 | ClassName& operator=(ClassName&&) = delete; \
106 | ClassName(const ClassName&) = delete; \
107 | ClassName& operator=(const ClassName&) = delete
108 |
109 | ////////////////////////////////////////////////////////////////////////////////
110 | // Seq / SeqWithStride / IndicesOf
111 | ////////////////////////////////////////////////////////////////////////////////
112 | //
113 | // Seq facilitates iterating over [begin, end) index ranges.
114 | //
115 | // * Avoids 3X stutter of the 'idx' variable, facilitating use of more
116 | // descriptive variable names like 'datapoint_idx', 'centroid_idx', etc.
117 | //
118 | // * Unifies the syntax between ParallelFor and vanilla for-loops.
119 | //
120 | // * Reverse iteration is much easier to read and less error prone.
121 | //
122 | // * Strided iteration becomes harder to miss when skimming code.
123 | //
124 | // * Reduction in boilerplate '=', '<', '+=' symbols makes it easier to
125 | // skim-read code with lots of small for-loops interleaved with operator-heavy
126 | // logic (i.e., most of ScaM).
127 | //
128 | // * Zero runtime overhead.
129 | //
130 | //
131 | // Equivalent for-loops (basic iteration):
132 | //
133 | // for (size_t idx : Seq(collection.size())) { ... }
134 | // for (size_t idx : Seq(0, collection.size())) { ... }
135 | // for (size_t idx = 0; idx < collection.size(); idx++) { ... }
136 | //
137 | //
138 | // In particular, reverse iteration becomes much simpler and more readable:
139 | //
140 | // for (size_t idx : ReverseSeq(collection.size())) { ... }
141 | // for (ssize_t idx = collection.size() - 1; idx >= 0; idx--) { ... }
142 | //
143 | //
144 | // Strided iteration works too:
145 | //
146 | // for (size_t idx : SeqWithStride<8>(filenames.size())) { ... }
147 | // for (size_t idx = 0; idx < filenames.size(); idx += 8) { ... }
148 | //
149 | //
150 | // Iteration count without using a variable:
151 | //
152 | // for (auto _ : Seq(16)) { ... }
153 | //
154 | //
155 | // Clarifies the ParallelFor syntax:
156 | //
157 | // ParallelFor<1>(Seq(dataset.size()), &pool, [&](size_t datapoint_idx) {
158 | // ...
159 | // });
160 | //
161 | template <size_t kStride>
162 | class SeqWithStride {
163 | public:
164 | static constexpr size_t Stride() { return kStride; }
165 |
166 | // Constructor for iterating [0, end).
167 | inline explicit SeqWithStride(size_t end) : begin_(0), end_(end) {}
168 |
169 | // Constructor for iterating [begin, end).
170 | inline SeqWithStride(size_t begin, size_t end) : begin_(begin), end_(end) {
171 | static_assert(kStride != 0);
172 | }
173 |
174 | // SizeT is an internal detail that helps suppress 'unused variable' compiler
175 | // errors. It's implicitly convertible to size_t, but by virtue of having a
176 | // destructor, the compiler doesn't complain about unused SizeT variables.
177 | //
178 | // These are equivalent:
179 | //
180 | // for (auto _ : Seq(10)) // Suppresses 'unused variable' error.
181 | // for (SizeT _ : Seq(10)) // Suppresses 'unused variable' error.
182 | //
183 | // Prefer the 'auto' variant. Don't use SizeT directly.
184 | //
185 | class SizeT {
186 | public:
187 | // Implicit SizeT <=> size_t conversions.
188 | inline SizeT(size_t val) : val_(val) {} // NOLINT
189 | inline operator size_t() const { return val_; } // NOLINT
190 |
191 | // Defining a destructor suppresses 'unused variable' errors for the
192 | // following pattern: for (auto _ : Seq(kNumIters)) { ... }
193 | inline ~SizeT() {}
194 |
195 | private:
196 | size_t val_;
197 | };
198 |
199 | // Iterator implements the "!=", "++" and "*" operators required to support
200 | // the C++ for-each syntax. Not intended for direct use.
201 | class Iterator {
202 | public:
203 | // Constructor.
204 | inline explicit Iterator(size_t idx) : idx_(idx) {}
205 | // The '*' operator.
206 | inline SizeT operator*() const { return idx_; }
207 | // The '++' operator.
208 | inline Iterator& operator++() {
209 | idx_ += kStride;
210 | return *this;
211 | }
212 | // The '!=' operator.
213 | inline bool operator!=(Iterator end) const {
214 | // Note: The comparison below is "<", not "!=", in order to generate the
215 | // correct behavior when (end - begin) is not a multiple of kStride; note
216 | // that the Iterator class only exists to support the C++ for-each syntax,
217 | // and is *not* intended for direct use.
218 | //
219 | // Consider the case where (end - begin) is not a multiple of kStride:
220 | //
221 | // for (size_t j : SeqWithStride<5>(9)) { ... }
222 | // for (size_t j = 0; j < 9; j += 5) { ... } // '==' wouldn't work.
223 | //
224 | if constexpr (kStride > 0) {
225 | return idx_ < end.idx_;
226 | }
227 |
228 | // The reverse-iteration case:
229 | //
230 | // for (size_t j : ReverseSeq(sz)) { ... }
231 | // for (ssize_t j = sz-1; j >= 0; j -= 5) { ... }
232 | //
233 | return static_cast<ssize_t>(idx_) >= static_cast<ssize_t>(end.idx_);
234 | }
235 |
236 | private:
237 | size_t idx_;
238 | };
239 | using iterator = Iterator;
240 |
241 | inline Iterator begin() const { return Iterator(begin_); }
242 | inline Iterator end() const { return Iterator(end_); }
243 |
244 | private:
245 | size_t begin_;
246 | size_t end_;
247 | };
248 |
249 | // Seq iterates [0, end)
250 | inline auto Seq(size_t end) { return SeqWithStride<1>(0, end); }
251 |
252 | // Seq iterates [begin, end).
253 | inline auto Seq(size_t begin, size_t end) {
254 | return SeqWithStride<1>(begin, end);
255 | }
256 |
257 | // IndicesOf provides the following equivalence class:
258 | //
259 | // for (size_t j : IndicesOf(container)) { ... }
260 | // for (size_t j : Seq(container.size())) { ...
}
261 |
262 | template <typename Container>
263 | SeqWithStride<1> IndicesOf(const Container& container) {
264 | return Seq(container.size());
265 | }
266 |
267 | ////////////////////////////////////////////////////////////////////////////////
268 | // Enumerate
269 | ////////////////////////////////////////////////////////////////////////////////
270 |
271 | template <typename T, typename IdxType = size_t,
272 | typename TIter = decltype(std::begin(std::declval<T>())),
273 | typename = decltype(std::end(std::declval<T>()))>
274 | constexpr auto Enumerate(T&& iterable) {
275 | class IteratorWithIndex {
276 | public:
277 | IteratorWithIndex(IdxType idx, TIter it) : idx_(idx), it_(it) {}
278 | bool operator!=(const IteratorWithIndex& other) const {
279 | return it_ != other.it_;
280 | }
281 | void operator++() { idx_++, it_++; }
282 | auto operator*() const { return std::tie(idx_, *it_); }
283 |
284 | private:
285 | IdxType idx_;
286 | TIter it_;
287 | };
288 | struct iterator_wrapper {
289 | T iterable;
290 | auto begin() { return IteratorWithIndex{0, std::begin(iterable)}; }
291 | auto end() { return IteratorWithIndex{0, std::end(iterable)}; }
292 | };
293 | return iterator_wrapper{std::forward<T>(iterable)};
294 | }
295 |
296 | ////////////////////////////////////////////////////////////////////////////////
297 | // Profiling Helpers
298 | ////////////////////////////////////////////////////////////////////////////////
299 |
300 | #define AR_ENDO_MARKER(...)
301 | #define AR_ENDO_MARKER_TIMEOUT(...)
302 | #define AR_ENDO_TASK(...)
303 | #define AR_ENDO_JOB(...)
304 | #define AR_ENDO_TASK_TIMEOUT(...)
305 | #define AR_ENDO_JOB_TIMEOUT(...)
306 | #define AR_ENDO_SCOPE(...)
307 | #define AR_ENDO_SCOPE_TIMEOUT(...)
308 | #define AR_ENDO_EVENT(...)
309 | #define AR_ENDO_ERROR(...)
310 | #define AR_ENDO_UNITS(...)
311 | #define AR_ENDO_THREAD_NAME(...)
312 | #define AR_ENDO_GROUP(...)
313 |
314 | } // namespace array_record
315 |
316 | #endif // ARRAY_RECORD_CPP_COMMON_H_
317 |
-------------------------------------------------------------------------------- /cpp/layout.proto: --------------------------------------------------------------------------------
1 | // Copyright 2022 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // Specialized Riegeli file format designed for IO heavy tasks. It works for
16 | // storing large blocks of raw data without a size limit as well as structured
17 | // data like protobufs. In both cases it supports random access and parallel
18 | // reads in a single file.
19 | syntax = "proto2";
20 |
21 | package array_record;
22 |
23 | // Riegeli files are composed of data chunks. Each data chunk contains multiple
24 | // records, and a record can be a serialized proto (with a size limit of 2GB) or
25 | // arbitrary bytes without size limits.
26 | //
27 | // Each Riegeli data chunk is encoded/compressed separately. The chunks are the
28 | // entry points for decoding, which allows us to read the chunks in parallel if
29 | // we know where these entry points are.
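//
// For example (an illustrative sketch, not part of the original comment):
// given the chunk offsets recorded in the footer described below, each worker
// can seek to its own entry point and decode independently, as the writer test
// above does:
//
//   reader.Seek(footers[i].chunk_offset());  // Entry point of the i-th chunk.
//   reader.ReadRecord(record);               // Decoding restarts cleanly here.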
30 | //
31 | // We would not know the offsets to the chunks until we serialized them. Hence,
32 | // a natural way to record these offsets is to store them as a Riegeli chunk in
33 | // the footer. Finally, to tell where the footer starts, we need a postscript
34 | // storing the offset of the footer. The postscript itself is also a Riegeli
35 | // chunk of size 64KB and is 64KB aligned. Therefore we can locate it by
36 | // subtracting 64KB from the file size. See the illustrated file layout below.
37 | //
38 | // +-----------------+
39 | // | User Data |
40 | // | Riegeli Chunk |
41 | // +-----------------+
42 | // | User Data |
43 | // | Riegeli Chunk |
44 | // +-----------------+
45 | // /\/\/\/\/\/\/\/\/\/\/
46 | // /\/\/\/\/\/\/\/\/\/\/ _+-----------------------+
47 | // +-----------------+ _/ | RiegeliFooterMetadata |
48 | // | Last User Data | __/ +-----------------------+
49 | // | Chunk | _/ | Footer Proto |
50 | // +-----------------_/ +-----------------------+
51 | // | | | Footer Proto |
52 | // | Footer Chunk | +-----------------------+
53 | // | | | Footer Proto |
54 | // +-----------------+---------+-----------------------+
55 | // |RiegeliPostscript| <--- Must align to 64KB and fit in 64KB.
56 | // +-----------------+
57 | //
58 | // The footer is composed of a header (RiegeliFooterMetadata) and an array of
59 | // proto records. We choose an array of proto records instead of repeated fields
60 | // in a single proto to avoid the 2GB proto size limit. We can use proto `enum`
61 | // or `oneof` in the footer metadata to create polymorphic indices. In other
62 | // words, we can develop indices beyond array-like access patterns and extend to
63 | // any data structures worth serializing to disk.
64 |
65 | // Footer proto for locating user data chunks.
66 | message ArrayRecordFooter {
67 | optional uint64 chunk_offset = 1;
68 | optional uint64 decoded_data_size = 2;
69 | optional uint64 num_records = 3;
70 | }
71 |
72 | // Metadata/Header in the footer.
73 | message RiegeliFooterMetadata {
74 | // Metadata for ArrayRecordFooter.
75 | message ArrayRecordMetadata {
76 | // Version number of the ArrayRecord format itself.
77 | // This number should rarely change unless there's a new layout design that
78 | // isn't backward compatible and whose performance and reliability gains
79 | // justify the implementation effort.
80 | optional uint32 version = 1;
81 | optional uint64 num_chunks = 2;
82 | optional uint64 num_records = 3;
83 |
84 | // Writer options for debugging purposes.
85 | optional string writer_options = 4;
86 | }
87 |
88 | oneof metadata {
89 | // Specifying the metadata to be array_record_metadata implies the footer
90 | // protos are ArrayRecordFooter.
91 | ArrayRecordMetadata array_record_metadata = 1;
92 | }
93 |
94 | // Useful field for us to support append in the future.
95 | reserved "previous_metadata";
96 | }
97 |
98 | // A small proto message serialized at the end of the file. This proto would be
99 | // stored three times in the postscript chunk for redundancy. Therefore its
100 | // serialized size, multiplied by the replication count of three, should be
101 | // smaller than the 64KB Riegeli block boundary.
102 | message RiegeliPostscript {
103 | optional uint64 footer_offset = 1;
104 | optional uint64 magic = 2;
105 |
106 | reserved "footer_encoding";
107 | }
108 |
-------------------------------------------------------------------------------- /cpp/masked_reader.cc: --------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/masked_reader.h"
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <utility>
21 |
22 | #include "absl/memory/memory.h"
23 | #include "absl/status/status.h"
24 | #include "absl/time/clock.h"
25 | #include "absl/time/time.h"
26 | #include "cpp/common.h"
27 | #include "riegeli/base/object.h"
28 | #include "riegeli/base/status.h"
29 | #include "riegeli/base/types.h"
30 |
31 | namespace array_record {
32 |
33 | using riegeli::Annotate;
34 | using riegeli::Position;
35 | using riegeli::Reader;
36 |
37 | MaskedReader::MaskedReader(std::unique_ptr<Reader> src_reader,
38 | size_t length)
39 | : buffer_(std::make_shared<std::string>()) {
40 | auto pos = src_reader->pos();
41 | buffer_->resize(length);
42 | if (!src_reader->Read(length, buffer_->data())) {
43 | Fail(Annotate(src_reader->status(),
44 | "Could not read from the underlying reader"));
45 | return;
46 | }
47 | /*
48 | * limit_pos
49 | * |---------------------------|
50 | * buffer_start buffer_limit
51 | * |................|----------|
52 | */
53 | set_buffer(buffer_->data(), buffer_->size());
54 | set_limit_pos(pos + buffer_->size());
55 | }
56 |
57 | MaskedReader::MaskedReader(std::shared_ptr<std::string> buffer,
58 | Position limit_pos)
59 | : buffer_(buffer) {
60 | /*
61 | * limit_pos
62 | * |---------------------------|
63 | * buffer_start buffer_limit
64 | * |................|----------|
65 | */
66 | set_buffer(buffer_->data(), buffer_->size());
67 | set_limit_pos(limit_pos);
68 | }
69 |
70 | MaskedReader::MaskedReader(MaskedReader &&other) noexcept
71 | : Reader(std::move(other)) {
72 | buffer_ = other.buffer_; // NOLINT(bugprone-use-after-move)
73 | other.Reset(riegeli::kClosed); // NOLINT(bugprone-use-after-move)
74 | }
75 |
76 | MaskedReader &MaskedReader::operator=(MaskedReader &&other) noexcept {
77 | // Move other
78 | Reader::operator=(static_cast<Reader&&>(other));
79 | // Copy the shared buffer.
80 | buffer_ = other.buffer_;
81 | // Close `other`
82 | other.Reset(riegeli::kClosed);
83 | return *this;
84 | }
85 |
86 | bool MaskedReader::PullSlow(size_t min_length, size_t recommended_length) {
87 | Fail(FailedPreconditionError("Should not pull beyond buffer"));
88 | return false;
89 | }
90 |
91 | bool MaskedReader::SeekSlow(riegeli::Position new_pos) {
92 | Fail(FailedPreconditionError("Should not seek beyond buffer"));
93 | return false;
94 | }
95 |
96 | absl::optional<Position> MaskedReader::SizeImpl() {
97 | return limit_pos();
98 | }
99 |
100 | std::unique_ptr<Reader> MaskedReader::NewReaderImpl(Position initial_pos) {
101 | if (!ok()) {
102 | return nullptr;
103 | }
104 | std::unique_ptr<MaskedReader> reader =
105 | absl::WrapUnique(new MaskedReader(buffer_, limit_pos()));
106 | if (!reader->Seek(initial_pos)) {
107 | return nullptr;
108 | }
109 | return reader;
110 | }
111 |
112 | } // namespace array_record
113 |
-------------------------------------------------------------------------------- /cpp/masked_reader.h: --------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_MASKED_READER_H_
17 | #define ARRAY_RECORD_CPP_MASKED_READER_H_
18 |
19 | #include <memory>
20 | #include <string>
21 |
22 | #include "absl/types/optional.h"
23 | #include "riegeli/base/object.h"
24 | #include "riegeli/base/types.h"
25 | #include "riegeli/bytes/reader.h"
26 |
27 | namespace array_record {
28 |
29 | // `MaskedReader` is a riegeli reader constructed from another reader with a
30 | // cached buffer over a specified region. Users can further derive new readers
31 | // that share the same buffer.
32 | //
33 | // original file |----------------------------------------------|
34 | //
35 | // masked buffer |..............|---------------|
36 | // |--------------------^ Position is addressed from the
37 | // beginning of the original file.
38 | //
39 | // The masked buffer and the original file use the same position base address.
40 | // Users cannot seek beyond the buffer region.
41 | //
42 | // This class is useful for reducing the number of PReads. Users may create a
43 | // MaskedReader containing multiple chunks, then derive multiple chunk readers
44 | // from this reader sharing the same buffer. Hence, there's only one PRead
45 | // issued for multiple chunks.
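// Usage sketch (illustrative only, not part of the original header; `base`
// stands for any riegeli::Reader positioned over the region of interest):
//
//   // One read pulls `length` bytes that cover several chunks.
//   MaskedReader masked(base.NewReader(region_begin), length);
//   // Derived readers share the cached buffer; no further reads are issued.
//   std::unique_ptr<riegeli::Reader> chunk_reader = masked.NewReader(chunk_pos);
//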
46 | class MaskedReader : public riegeli::Reader {
47 | public:
48 | explicit MaskedReader(riegeli::Closed) : riegeli::Reader(riegeli::kClosed) {}
49 |
50 | MaskedReader(std::unique_ptr<riegeli::Reader> src_reader, size_t length);
51 |
52 | MaskedReader(MaskedReader &&other) noexcept;
53 | MaskedReader &operator=(MaskedReader &&other) noexcept;
54 |
55 | bool SupportsRandomAccess() override { return true; }
56 | bool SupportsNewReader() override { return true; }
57 |
58 | protected:
59 | bool PullSlow(size_t min_length, size_t recommended_length) override;
60 | bool SeekSlow(riegeli::Position new_pos) override;
61 |
62 | absl::optional<riegeli::Position> SizeImpl() override;
63 | std::unique_ptr<riegeli::Reader> NewReaderImpl(
64 | riegeli::Position initial_pos) override;
65 |
66 | private:
67 | // Private constructor for deriving a new reader that shares the buffer.
68 | MaskedReader(std::shared_ptr<std::string> buffer,
69 | riegeli::Position limit_pos);
70 |
71 | std::shared_ptr<std::string> buffer_;
72 | };
73 |
74 | } // namespace array_record
75 |
76 | #endif // ARRAY_RECORD_CPP_MASKED_READER_H_
77 |
-------------------------------------------------------------------------------- /cpp/masked_reader_test.cc: --------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "cpp/masked_reader.h"
17 |
18 | #include <string>
19 |
20 | #include "gtest/gtest.h"
21 | #include "riegeli/bytes/string_reader.h"
22 |
23 | namespace array_record {
24 | namespace {
25 |
26 | using riegeli::StringReader;
27 |
28 | TEST(MaskedReaderTest, SanityTest) {
29 | auto data = std::string("0123456789abcdef");
30 | auto base_reader = StringReader(data);
31 | // 56789abc
32 | auto masked_reader1 = MaskedReader(base_reader.NewReader(5), 8);
33 | // Matches where we offset the reader.
34 | EXPECT_EQ(masked_reader1.pos(), 5);
35 | // Matches offset + mask length
36 | EXPECT_EQ(masked_reader1.Size(), 8 + 5);
37 | {
38 | std::string result;
39 | masked_reader1.Read(4, result);
40 | EXPECT_EQ(result, "5678");
41 | EXPECT_EQ(masked_reader1.pos(), 5 + 4);
42 | masked_reader1.Read(4, result);
43 | EXPECT_EQ(result, "9abc");
44 | EXPECT_EQ(masked_reader1.pos(), 5 + 8);
45 | }
46 |
47 | auto masked_reader2 = masked_reader1.NewReader(7);
48 | // Size does not change
49 | EXPECT_EQ(masked_reader2->Size(), 8 + 5);
50 | // pos is the new position we set from NewReader
51 | EXPECT_EQ(masked_reader2->pos(), 7);
52 | {
53 | std::string result;
54 | masked_reader2->Read(4, result);
55 | EXPECT_EQ(result, "789a");
56 | }
57 |
58 | // Reaching a position that is out of bounds does not fail the base reader.
59 | // It simply returns a nullptr.
60 | EXPECT_EQ(masked_reader1.NewReader(0), nullptr);
61 | EXPECT_EQ(masked_reader1.NewReader(20), nullptr);
62 | EXPECT_TRUE(masked_reader1.ok());
63 |
64 | // Support seek
65 | masked_reader1.Seek(6);
66 | {
67 | std::string result;
68 | masked_reader1.Read(4, result);
69 | EXPECT_EQ(result, "6789");
70 | }
71 |
72 | // Seek beyond buffer is a failure.
73 | EXPECT_FALSE(masked_reader1.Seek(20));
74 | }
75 |
76 | } // namespace
77 |
78 | } // namespace array_record
79 |
-------------------------------------------------------------------------------- /cpp/parallel_for.h: --------------------------------------------------------------------------------
1 | /* Copyright 2022 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef ARRAY_RECORD_CPP_PARALLEL_FOR_H_
17 | #define ARRAY_RECORD_CPP_PARALLEL_FOR_H_
18 |
19 | #include <algorithm>
20 | #include <atomic>
21 | #include <cstddef>
22 | #include <limits>
23 | #include <type_traits>
24 |
25 | #include "absl/base/optimization.h"
26 | #include "absl/status/status.h"
27 | #include "absl/synchronization/mutex.h"
28 | #include "cpp/common.h"
29 | #include "cpp/thread_pool.h"
30 |
31 | namespace array_record {
32 |
33 | // kDynamicBatchSize - when a batch size isn't specified, ParallelFor defaults
34 | // to dividing the work into (4 * num_threads) batches, enabling decently good
35 | // parallelism, while minimizing coordination overhead.
36 | enum : size_t {
37 | kDynamicBatchSize = std::numeric_limits<size_t>::max(),
38 | };
39 |
40 | // Options for ParallelFor. The defaults are sufficient for most users.
41 | struct ParallelForOptions {
42 | // It may be desirable to limit parallelism in some cases if e.g.:
43 | // 1. A portion of the loop body requires synchronization and Amdahl's Law
44 | // prevents scaling past a small number of threads.
45 | // 2. You're running on a NUMA system and don't want this loop to execute
46 | // across NUMA nodes.
47 | size_t max_parallelism = std::numeric_limits<size_t>::max();
48 | };
49 |
50 | // ParallelFor - execute a for-loop in parallel, using both the calling thread
51 | // and the threads available in an ARThreadPool argument.
52 | //
53 | // Arguments:
54 | //
55 | // * <seq> - the sequence to be processed, e.g. "Seq(vec.size())".
56 | //
57 | // * <pool> - the threadpool to use (in addition to the main thread). If this
58 | // parameter is nullptr, then ParallelFor will compile down to a vanilla
59 | // single-threaded for-loop. It is permissible for the calling thread to be a
60 | // member of <pool>.
61 | //
62 | // * <function> - the method to call for each value in <seq>.
63 | //
64 | // * <kItersPerBatch> [template param] - the number of calls to <function> that
65 | // each thread will perform before synchronizing access to the loop counter. If
66 | // not provided, the array will be divided into (n_threads * 4) work batches,
67 | // enabling good parallelism in most cases, while minimizing synchronization
68 | // overhead.
69 | //
70 | //
71 | // Example: Compute 1M sqrts in parallel.
72 | //
73 | // vector<double> sqrts(1000000);
74 | // ParallelFor(Seq(sqrts.size()), &pool, [&](size_t i) {
75 | // sqrts[i] = sqrt(i);
76 | // });
77 | //
78 | //
79 | // Example: Only the evens, using SeqWithStride.
80 | //
81 | // ParallelFor(SeqWithStride<2>(0, sqrts.size()), &pool, [&sqrts](size_t i) {
82 | // sqrts[i] = i;
83 | // });
84 | //
85 | //
86 | // Example: Execute N expensive tasks in parallel, in batches of 1-at-a-time.
87 | //
88 | // ParallelFor<1>(Seq(tasks.size()), &pool, [&](size_t j) {
89 | // DoSomethingExpensive(tasks[j]);
90 | // });
91 | //
92 | template <size_t kItersPerBatch = kDynamicBatchSize, typename SeqT,
93 | typename Function>
94 | inline void ParallelFor(SeqT seq, ARThreadPool* pool, Function func,
95 | ParallelForOptions opts = ParallelForOptions());
96 |
97 | // ParallelForWithStatus - Similar to ParallelFor, except it can short-circuit
98 | // if an error occurs.
99 |
100 | // Arguments:
101 | // * <seq> - the sequence to be processed, e.g. "Seq(vec.size())".
102 | //
103 | // * <pool> - the threadpool to use (in addition to the main thread). If this
104 | // parameter is nullptr, then ParallelFor will compile down to a vanilla
105 | // single-threaded for-loop. It is permissible for the calling thread to be a
106 | // member of <pool>.
107 | //
108 | // * <function> - the method to call for each value in <seq> with type
109 | // interface of std::function<absl::Status(size_t idx)>.
110 | //
111 | // * <kItersPerBatch> [template param] - the number of calls to <function> that
112 | // each thread will perform before synchronizing access to the loop counter. If
113 | // not provided, the array will be divided into (n_threads * 4) work batches,
114 | // enabling good parallelism in most cases, while minimizing synchronization
115 | // overhead.
116 | //
117 | // Example:
118 | //
119 | // auto status = ParallelForWithStatus<1>(
120 | // Seq(tasks.size()), &pool, [&](size_t idx) -> absl::Status {
121 | // RETURN_IF_ERROR(RunTask(tasks[idx]));
122 | // return absl::OkStatus();
123 | // });
124 | //
125 | template <size_t kItersPerBatch = kDynamicBatchSize, typename SeqT,
126 | typename Function>
127 | inline absl::Status ParallelForWithStatus(
128 | SeqT seq, ARThreadPool* pool, Function Func,
129 | ParallelForOptions opts = ParallelForOptions()) {
130 | absl::Status finite_check_status = absl::OkStatus();
131 |
132 | std::atomic_bool is_ok_status{true};
133 | absl::Mutex mutex;
134 | ParallelFor(
135 | seq, pool,
136 | [&](size_t idx) {
137 | if (!is_ok_status.load(std::memory_order_relaxed)) {
138 | return;
139 | }
140 | absl::Status status = Func(idx);
141 | if (!status.ok()) {
142 | absl::MutexLock lock(&mutex);
143 | finite_check_status = status;
144 | is_ok_status.store(false, std::memory_order_relaxed);
145 | }
146 | },
147 | opts);
148 | return finite_check_status;
149 | }
150 |
151 | ////////////////////////////////////////////////////////////////////////////////
152 | // IMPLEMENTATION DETAILS
153 | ////////////////////////////////////////////////////////////////////////////////
154 |
155 | namespace parallel_for_internal {
156 |
157 | // ParallelForClosure - a single heap-allocated object that holds the loop's
158 | // state. The object will delete itself when the final task completes.
159 | template <size_t kItersPerBatch, typename SeqT, typename Function>
160 | class ParallelForClosure {
161 | public:
162 | static constexpr bool kIsDynamicBatch = (kItersPerBatch == kDynamicBatchSize);
163 | ParallelForClosure(SeqT seq, Function func)
164 | : func_(func),
165 | index_(*seq.begin()),
166 | range_end_(*seq.end()),
167 | reference_count_(1) {}
168 |
169 | inline void RunParallel(ARThreadPool* pool, size_t desired_threads) {
170 | // Don't push more tasks to the pool than we have work for. Also, if
    // Also, if parallelism is limited by desired_threads rather than by the
    // thread pool size, subtract 1 from the number of threads to push, to
    // account for the main thread.
    size_t n_threads =
        std::min<size_t>(desired_threads - 1, pool->NumThreads());

    // Handle dynamic batch size.
    if (kIsDynamicBatch) {
      batch_size_ =
          SeqT::Stride() * std::max(1ul, desired_threads / 4 / n_threads);
    }

    reference_count_ += n_threads;
    while (n_threads--) {
      pool->Schedule([this]() { Run(); });
    }

    // Do work on the main thread. Once this returns, we are guaranteed that
    // all batches have been assigned to some thread.
    DoWork();

    // Then wait for all worker threads to exit the core loop. Thus, once the
    // main thread is able to take a WriterLock, we are guaranteed that all
    // batches have finished, allowing the main thread to move on.
    //
    // The main thread does *NOT* wait for ARThreadPool tasks that haven't yet
    // entered the core loop. This is important for handling scenarios where
    // the ARThreadPool falls significantly behind and hasn't started some of
    // the tasks assigned to it. Once assigned, those tasks will quickly
    // realize that there is no work left, and the final task to schedule will
    // delete this heap-allocated object.
    //
    termination_mutex_.WriterLock();
    termination_mutex_.WriterUnlock();

    // Drop the main thread's reference.
    if (--reference_count_ == 0) delete this;
  }

  void Run() {
    // Do work on a child thread. Before starting any work, each child thread
    // takes a reader lock, preventing the main thread from finishing while
    // any child threads are still executing in the core loop.
    termination_mutex_.ReaderLock();
    DoWork();
    termination_mutex_.ReaderUnlock();

    // Drop the child thread's reference.
    if (--reference_count_ == 0) delete this;
  }

  // DoWork - the "core loop", executed in parallel on N threads.
  inline void DoWork() {
    // Performance Note: Copying constant values to the stack allows the
    // compiler to know that they are actually constant and can be assigned to
    // registers (the 'const' keyword is insufficient).
    const size_t range_end = range_end_;

    // Performance Note: when the batch size is not dynamic, the compiler will
    // treat it as a constant that can be directly inlined into the code w/o
    // consuming a register.
    constexpr size_t kStaticBatchSize = SeqT::Stride() * kItersPerBatch;
    const size_t batch_size = kIsDynamicBatch ? batch_size_ : kStaticBatchSize;

    // The core loop:
    for (;;) {
      // The std::atomic index_ coordinates sequential batch assignment.
      const size_t batch_begin = index_.fetch_add(batch_size);
      // Once assigned, batches execute w/o further coordination.
      const size_t batch_end = std::min(batch_begin + batch_size, range_end);
      if (ABSL_PREDICT_FALSE(batch_begin >= range_end)) break;
      for (size_t idx : SeqWithStride<SeqT::Stride()>(batch_begin, batch_end)) {
        func_(idx);
      }
    }
  }

 private:
  Function func_;

  // index_ is used by worker threads to coordinate sequential batch
  // assignment. This is the only coordination mechanism inside the core loop.
  std::atomic<size_t> index_;

  // The iteration stops at range_end_.
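  // (range_end_ is an exclusive bound: DoWork() stops handing out batches
  // once batch_begin reaches it.)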
  const size_t range_end_;

  // termination_mutex_ coordinates termination of the for-loop, allowing the
  // main thread to block until all child threads have exited the core loop.
  absl::Mutex termination_mutex_;

  // For smaller arrays, it's possible for the work to finish before some of
  // the ARThreadPool tasks have started running, and in some extreme cases, it
  // might take entire milliseconds before these tasks begin running. The main
  // thread will continue doing other work, and the last task to schedule and
  // terminate will delete this heap-allocated object.
  std::atomic<size_t> reference_count_;

  // The batch_size_ member variable is only used when the batch size is
  // dynamic (i.e., kItersPerBatch == kDynamicBatchSize).
  size_t batch_size_ = kItersPerBatch;
};

}  // namespace parallel_for_internal

template <size_t kItersPerBatch, typename SeqT, typename Function>
inline void ParallelFor(SeqT seq, ARThreadPool* pool, Function func,
                        ParallelForOptions opts) {
  // Figure out how many batches of work we have.
  constexpr size_t kMinItersPerBatch =
      kItersPerBatch == kDynamicBatchSize ? 1 : kItersPerBatch;
  const size_t desired_threads = std::min(
      opts.max_parallelism, DivRoundUp(*seq.end() - *seq.begin(),
                                       SeqT::Stride() * kMinItersPerBatch));

  // Unfortunately the TF ThreadPool has no interface to monitor queue
  // fullness, so we fall back to a serialized vanilla for-loop for handling
  // any of:
  //
  // * No ThreadPool provided.
  //
  // * ThreadPool has fallen very far behind.
  //
  // * Small arrays w/ only 1 batch of work to do.
  //
  // Note that the compiler will inline this logic, making ParallelFor
  // equivalent to a traditional C++ for-loop for the cases above.
  if (pool == nullptr || desired_threads <= 1) {
    for (size_t idx : seq) {
      func(idx);
    }
    return;
  }

  // Otherwise, fire up the threadpool. Note that the shared closure object is
  // heap-allocated, allowing this method to finish even if some tasks haven't
  // started running yet. The object will be deleted by the last finishing
  // task, or possibly by this thread, whichever is last to terminate.
  using parallel_for_internal::ParallelForClosure;
  auto closure =
      new ParallelForClosure<kItersPerBatch, SeqT, Function>(seq, func);
  closure->RunParallel(pool, desired_threads);
}

}  // namespace array_record

#endif  // ARRAY_RECORD_CPP_PARALLEL_FOR_H_
--------------------------------------------------------------------------------
/cpp/parallel_for_test.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests for parallel_for.h.
#include "cpp/parallel_for.h"

#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "cpp/common.h"
#include "cpp/thread_pool.h"

namespace array_record {

class ParallelForTest : public testing::Test {
 protected:
  void SetUp() final { pool_ = ArrayRecordGlobalPool(); }

 public:
  static constexpr int32_t kNumElements = 1000000;
  ARThreadPool* pool_;
};

TEST_F(ParallelForTest, AtomicCounterTest) {
  std::vector<int32_t> items(kNumElements, 0);
  ParallelFor(Seq(kNumElements), pool_, [&](size_t j) { items[j] = 1; });
  int32_t num_accessed = 0;
  for (auto item_was_accessed : items) {
    if (item_was_accessed) {
      num_accessed++;
    }
  }
  EXPECT_EQ(num_accessed, kNumElements);
}

TEST_F(ParallelForTest, ImplicitStepImplicitBlock) {
  std::vector<double> result(kNumElements);
  auto l = [&result](size_t j) { result[j] += sqrt(j); };

  ParallelFor(Seq(kNumElements), pool_, l);
  for (size_t j : Seq(kNumElements)) {
    EXPECT_EQ(sqrt(j), result[j]);
  }

  result.clear();
  result.resize(kNumElements);
  ParallelFor(Seq(kNumElements), nullptr, l);
  for (size_t j : Seq(kNumElements)) {
    EXPECT_EQ(sqrt(j), result[j]);
  }
}

TEST_F(ParallelForTest, ImplicitStepExplicitBlock) {
  std::vector<double> result(kNumElements);
  auto l = [&result](size_t j) { result[j] += sqrt(j); };
  ParallelFor<10>(Seq(kNumElements), pool_, l);

  for (size_t j : Seq(kNumElements)) {
    EXPECT_EQ(sqrt(j), result[j]);
  }

  result.clear();
  result.resize(kNumElements);
  ParallelFor<10>(Seq(kNumElements), nullptr, l);

  for (size_t j : Seq(kNumElements)) {
    EXPECT_EQ(sqrt(j), result[j]);
  }
}

TEST_F(ParallelForTest, ExplicitStepExplicitBlock) {
  std::vector<double> result(kNumElements);
  auto l = [&result](size_t j) { result[j] += sqrt(j); };
  ParallelFor<10>(SeqWithStride<2>(kNumElements), pool_, l);

  for (size_t j : Seq(kNumElements)) {
    // We only did the even-numbered elements, so the odd ones should be zero.
    if (j & 1) {
      EXPECT_EQ(result[j], 0.0);
    } else {
      EXPECT_EQ(sqrt(j), result[j]);
    }
  }

  result.clear();
  result.resize(kNumElements);
  ParallelFor<10>(SeqWithStride<2>(kNumElements), nullptr, l);

  for (size_t j : Seq(kNumElements)) {
    // We only did the even-numbered elements, so the odd ones should be zero.
    if (j & 1) {
      EXPECT_EQ(result[j], 0.0);
    } else {
      EXPECT_EQ(sqrt(j), result[j]);
    }
  }
}

TEST_F(ParallelForTest, ExplicitStepImplicitBlock) {
  std::vector<double> result(kNumElements);
  auto l = [&result](size_t j) { result[j] += sqrt(j); };
  ParallelFor(SeqWithStride<2>(kNumElements), pool_, l);

  for (size_t j : Seq(kNumElements)) {
    // We only did the even-numbered elements, so the odd ones should be zero.
    if (j & 1) {
      EXPECT_EQ(result[j], 0.0);
    } else {
      EXPECT_EQ(sqrt(j), result[j]);
    }
  }

  result.clear();
  result.resize(kNumElements);
  ParallelFor(SeqWithStride<2>(kNumElements), nullptr, l);

  for (size_t j : Seq(kNumElements)) {
    // We only did the even-numbered elements, so the odd ones should be zero.
    if (j & 1) {
      EXPECT_EQ(result[j], 0.0);
    } else {
      EXPECT_EQ(sqrt(j), result[j]);
    }
  }
}

TEST_F(ParallelForTest, ExampleCompiles) {
  // With lambdas, usage of this library is very clean, as illustrated below.
  // Compute the square root of every number from 0 to 1000000 in parallel.
  std::vector<double> sqrts(1000000);
  auto pool = ArrayRecordGlobalPool();

  ParallelFor(Seq(sqrts.size()), pool,
              [&sqrts](size_t j) { sqrts[j] = sqrt(j); });

  // Only compute the square roots of even numbers by using an explicit step.
  ParallelFor(SeqWithStride<2>(sqrts.size()), pool,
              [&sqrts](size_t j) { sqrts[j] = j; });

  // The block_size parameter can be adjusted to control the granularity of
  // parallelism. This parameter represents the number of iterations of the
  // loop that will be done in a single thread before communicating with any
  // other threads. Smaller block sizes lead to better load balancing between
  // threads. Larger block sizes lead to less communication overhead and less
  // risk of false sharing (http://en.wikipedia.org/wiki/False_sharing)
  // when writing to adjacent array elements from different threads based
  // on the loop index. The default block size creates approximately 4
  // blocks per thread.
  //
  // Use an explicit block size of 10.
  ParallelFor<10>(SeqWithStride<2>(sqrts.size()), pool,
                  [&sqrts](size_t j) { sqrts[j] = j; });
}

TEST_F(ParallelForTest, ParallelForWithStatusTest) {
  std::atomic_int counter = 0;
  auto status =
      ParallelForWithStatus<1>(Seq(kNumElements), pool_, [&](size_t i) {
        counter.fetch_add(1, std::memory_order_release);
        return absl::OkStatus();
      });
  EXPECT_TRUE(status.ok());
  EXPECT_EQ(counter.load(std::memory_order_acquire), kNumElements);
}

TEST_F(ParallelForTest, ParallelForWithStatusTestShortCircuit) {
  std::atomic_int counter = 0;
  auto status =
      ParallelForWithStatus<1>(Seq(kNumElements), pool_, [&](size_t i) {
        counter.fetch_add(1, std::memory_order_release);
        return absl::UnknownError("Intended error");
      });
  EXPECT_EQ(status.code(), absl::StatusCode::kUnknown);
  EXPECT_LT(counter.load(std::memory_order_acquire), kNumElements);
}

}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/sequenced_chunk_writer.h"

#include <chrono>  // NOLINT(build/c++11)
#include <cstdint>
#include <future>  // NOLINT(build/c++11)
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "riegeli/base/status.h"
#include "riegeli/base/types.h"
#include "riegeli/chunk_encoding/chunk.h"
#include "riegeli/chunk_encoding/constants.h"
#include "riegeli/records/chunk_writer.h"

namespace array_record {

bool SequencedChunkWriterBase::CommitFutureChunk(
    std::future<absl::StatusOr<riegeli::Chunk>>&& future_chunk) {
  absl::MutexLock l(&mu_);
  if (!ok()) {
    return false;
  }
  queue_.push(std::move(future_chunk));
  return true;
}

bool SequencedChunkWriterBase::SubmitFutureChunks(bool block) {
  // We need to use TryLock to prevent deadlock.
  //
  // std::future::get() blocks if the result isn't ready yet.
  // Hence the following scenario triggers a deadlock.
  // T1:
  //   SubmitFutureChunks(true)
  //   mu_ held
  //   Blocks on queue_.front().get();
  // T2:
  //   In charge of fulfilling the future of queue_.front() on its exit.
  //   SubmitFutureChunks(false)
  //   Blocks on mu_ if we used mu_.Lock() instead of mu_.TryLock()
  //
  // NOTE: Even if ok() is false, the loop below will drain queue_, either
  // completely if block is true, or until a non-ready future is at the front
  // of the queue in the non-blocking case. If ok() is false, the front element
  // is popped from the queue and discarded.

  if (block) {
    // When blocking, we block both on mutex acquisition and on future
    // completion.
    absl::MutexLock lock(&mu_);
    riegeli::ChunkWriter* writer = get_writer();
    while (!queue_.empty()) {
      TrySubmitFirstFutureChunk(writer);
    }
    return ok();
  } else if (mu_.TryLock()) {
    // When non-blocking, we only proceed if we can lock the mutex without
    // blocking, and we only process those futures that are ready. We need
    // to unlock the mutex manually in this case, and take care to call ok()
    // under the lock.
    riegeli::ChunkWriter* writer = get_writer();
    while (!queue_.empty() &&
           queue_.front().wait_for(std::chrono::microseconds::zero()) ==
               std::future_status::ready) {
      TrySubmitFirstFutureChunk(writer);
    }
    bool result = ok();
    mu_.Unlock();
    return result;
  } else {
    return true;
  }
}

void SequencedChunkWriterBase::TrySubmitFirstFutureChunk(
    riegeli::ChunkWriter* chunk_writer) {
  auto status_or_chunk = queue_.front().get();
  queue_.pop();

  if (!ok() || !chunk_writer->ok()) {
    // Note (see above): the front of the queue is popped even if we discard it
    // now.
    return;
  }
  // Set self unhealthy for bad chunks.
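  // (Per the NOTE in SubmitFutureChunks: once Fail() marks this writer
  // unhealthy, later queue entries are still extracted but discarded.)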
  if (!status_or_chunk.ok()) {
    Fail(riegeli::Annotate(
        status_or_chunk.status(),
        absl::StrFormat("Could not submit chunk: %d", submitted_chunks_)));
    return;
  }
  riegeli::Chunk chunk = std::move(status_or_chunk.value());
  uint64_t chunk_offset = chunk_writer->pos();
  uint64_t decoded_data_size = chunk.header.decoded_data_size();
  uint64_t num_records = chunk.header.num_records();

  if (!chunk_writer->WriteChunk(std::move(chunk))) {
    Fail(riegeli::Annotate(
        chunk_writer->status(),
        absl::StrFormat("Could not submit chunk: %d", submitted_chunks_)));
    return;
  }
  if (pad_to_block_boundary_) {
    if (!chunk_writer->PadToBlockBoundary()) {
      Fail(riegeli::Annotate(
          chunk_writer->status(),
          absl::StrFormat("Could not pad boundary for chunk: %d",
                          submitted_chunks_)));
      return;
    }
  }
  if (!chunk_writer->Flush(riegeli::FlushType::kFromObject)) {
    Fail(riegeli::Annotate(
        chunk_writer->status(),
        absl::StrFormat("Could not flush chunk: %d", submitted_chunks_)));
    return;
  }
  if (callback_) {
    (*callback_)(submitted_chunks_, chunk_offset, decoded_data_size,
                 num_records);
  }
  submitted_chunks_++;
}

void SequencedChunkWriterBase::Initialize() {
  auto* chunk_writer = get_writer();
  riegeli::Chunk chunk;
  chunk.header = riegeli::ChunkHeader(chunk.data,
                                      riegeli::ChunkType::kFileSignature, 0, 0);
  if (!chunk_writer->WriteChunk(chunk)) {
    Fail(riegeli::Annotate(chunk_writer->status(),
                           "Failed to create the file header"));
  }
  if (!chunk_writer->Flush(riegeli::FlushType::kFromObject)) {
    Fail(riegeli::Annotate(chunk_writer->status(), "Could not flush"));
  }
}

void SequencedChunkWriterBase::Done() {
  if (!SubmitFutureChunks(true)) {
    Fail(absl::InternalError("Unable to submit pending chunks"));
    return;
  }
  auto* chunk_writer = get_writer();
  if (!chunk_writer->Close()) {
    Fail(riegeli::Annotate(chunk_writer->status(),
                           "Failed to close chunk_writer"));
  }
}

}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer.h:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Low-level API for building off-memory data structures with riegeli.
//
// `SequencedChunkWriter` writes chunks of records to an abstracted
// destination. This class abstracts the generic chunk-writing logic away from
// the concrete logic that builds up the data structures for future access.
// This class is thread-safe and allows users to encode each chunk
// concurrently while maintaining the sequence order of the chunks.

#ifndef ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_
#define ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_

#include <cstdint>
#include <future>  // NOLINT(build/c++11)
#include <queue>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cpp/common.h"
#include "riegeli/base/initializer.h"
#include "riegeli/base/object.h"
#include "riegeli/bytes/writer.h"
#include "riegeli/chunk_encoding/chunk.h"
#include "riegeli/records/chunk_writer.h"

namespace array_record {

// Template-parameter-independent part of `SequencedChunkWriter`.
class SequencedChunkWriterBase : public riegeli::Object {
  SequencedChunkWriterBase(const SequencedChunkWriterBase&) = delete;
  SequencedChunkWriterBase& operator=(const SequencedChunkWriterBase&) = delete;
  SequencedChunkWriterBase(SequencedChunkWriterBase&&) = delete;
  SequencedChunkWriterBase& operator=(SequencedChunkWriterBase&&) = delete;

 public:
  // The `SequencedChunkWriter` invokes `SubmitChunkCallback` for every
  // successfully written chunk. In other words, the invocation only happens
  // when none of the possible errors occur (SequencedChunkWriter internal
  // state, chunk correctness, underlying writer state, etc.) The callback
  // consists of four arguments:
  //
  // chunk_seq: sequence number of the chunk in the file. Indexed from 0.
  // chunk_offset: byte offset of the chunk in the file. A reader can seek to
  //   this offset and decode the chunk without other information.
  // decoded_data_size: byte size of the decoded data. Users may serialize this
  //   field for readers to allocate memory for the decoded data.
  // num_records: number of records in the chunk.
  class SubmitChunkCallback {
   public:
    virtual ~SubmitChunkCallback() {}
    virtual void operator()(uint64_t chunk_seq, uint64_t chunk_offset,
                            uint64_t decoded_data_size,
                            uint64_t num_records) = 0;
  };

  // Commits a future chunk to the `SequencedChunkWriter` before materializing
  // the chunk. Users can encode the chunk in a separate thread at the cost of
  // larger temporary memory usage. `SequencedChunkWriter` serializes the
  // chunks in the order of these calls.
  //
  // Example 1: packaged_task
  //
  //   std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
  //       []() -> absl::StatusOr<riegeli::Chunk> {
  //         ... returns a riegeli::Chunk on success.
  //       });
  //   std::future<absl::StatusOr<riegeli::Chunk>> task_future =
  //       encoding_task.get_future();
  //   sequenced_chunk_writer->CommitFutureChunk(std::move(task_future));
  //
  //   // Computes the encoding task in a thread pool.
  //   pool->Schedule(std::move(encoding_task));
  //
  // Example 2: promise and future
  //
  //   std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
  //   RET_CHECK(sequenced_chunk_writer->CommitFutureChunk(
  //       chunk_promise.get_future())) << sequenced_chunk_writer->status();
  //   pool->Schedule([chunk_promise = std::move(chunk_promise)] {
  //     // computes chunk
  //     chunk_promise.set_value(status_or_chunk);
  //   });
  //
  // Although `SequencedChunkWriter` is thread-safe, this method should be
  // invoked from a single thread because it doesn't make sense to submit
  // future chunks without a well-defined order.
  bool CommitFutureChunk(
      std::future<absl::StatusOr<riegeli::Chunk>>&& future_chunk);

  // Extracts the future chunks and submits them to the underlying destination.
  // This operation may block if the argument `block` is true. This method is
  // thread-safe, and we recommend users invoke it with `block=false` in each
  // worker thread to reduce the temporary memory usage.
  //
  // If ok() is false before or during this operation, queue elements continue
  // to be extracted, but are immediately discarded.
  //
  // Example 1: single-thread usage
  //
  //   std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
  //   RET_CHECK(sequenced_chunk_writer->CommitFutureChunk(
  //       chunk_promise.get_future())) << sequenced_chunk_writer->status();
  //   chunk_promise.set_value(ComputesChunk());
  //   RET_CHECK(writer->SubmitFutureChunks(true)) << writer->status();
  //
  // Example 2: concurrent access
  //
  //   riegeli::SharedPtr<SequencedChunkWriter<...>> writer(
  //       riegeli::Maker(...));
  //   std::promise<absl::StatusOr<riegeli::Chunk>> chunk_promise;
  //   RET_CHECK(writer->CommitFutureChunk(chunk_promise.get_future()))
  //       << writer->status();
  //   pool->Schedule([writer,
  //                   chunk_promise = std::move(chunk_promise)]() mutable {
  //     chunk_promise.set_value(status_or_chunk);
  //     // Should not block, otherwise we would enter a deadlock!
  //     writer->SubmitFutureChunks(false);
  //   });
  //   // Blocking the main thread is fine.
  //   RET_CHECK(writer->SubmitFutureChunks(true)) << writer->status();
  //
  bool SubmitFutureChunks(bool block = false);

  // Pads to the 64KB block boundary for future chunk submissions. (Default
  // false.)
  void set_pad_to_block_boundary(bool pad_to_block_boundary) {
    absl::MutexLock l(&mu_);
    pad_to_block_boundary_ = pad_to_block_boundary;
  }
  bool pad_to_block_boundary() {
    absl::MutexLock l(&mu_);
    return pad_to_block_boundary_;
  }

  // Sets up a callback for each committed chunk. See the SubmitChunkCallback
  // comments for details.
  void set_submit_chunk_callback(SubmitChunkCallback* callback) {
    absl::MutexLock l(&mu_);
    callback_ = callback;
  }

  // Guards the status access.
  absl::Status status() const {
    absl::ReaderMutexLock l(&mu_);
    return riegeli::Object::status();
  }

 protected:
  SequencedChunkWriterBase() {}
  virtual riegeli::ChunkWriter* get_writer() = 0;

  // Initializes and validates the underlying writer state.
  void Initialize();

  // Callback for riegeli::Object::Close.
  void Done() override;

 private:
  // Attempts to submit the first chunk from the queue. Expects that the lock
  // is already held. Even if ok() is false on entry already, the queue element
  // is removed (and discarded).
  void TrySubmitFirstFutureChunk(riegeli::ChunkWriter* chunk_writer)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutable absl::Mutex mu_;
  bool pad_to_block_boundary_ ABSL_GUARDED_BY(mu_) = false;
  SubmitChunkCallback* callback_ ABSL_GUARDED_BY(mu_) = nullptr;

  // Records the sequence number of submitted chunks.
  uint64_t submitted_chunks_ ABSL_GUARDED_BY(mu_) = 0;

  // Queue for storing the future chunks.
  std::queue<std::future<absl::StatusOr<riegeli::Chunk>>> queue_
      ABSL_GUARDED_BY(mu_);
};

// A `SequencedChunkWriter` writes chunks (blobs of multiple, possibly
// compressed records) rather than individual records to an abstracted
// destination. `SequencedChunkWriter` allows users to encode each chunk
// concurrently while keeping the chunk sequence order equal to the input
// order.
//
// Users can also supply a `SubmitChunkCallback` to collect chunk sequence
// numbers, offsets in the file, decoded data sizes, and the number of records
// in each chunk. Users may use the callback information to produce a lookup
// table in the footer, so an efficient reader can decode multiple chunks in
// parallel.
//
// Example usage:
//
//   // Step 1: open the writer with a file backend.
//   File* file = file::OpenOrDie(...);
//   riegeli::SharedPtr<SequencedChunkWriter<riegeli::FdWriter<>>> writer(
//       riegeli::Maker(riegeli::Maker(filename_or_file)));
//
//   // Step 2: create a chunk encoding task.
//   std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
//       []() -> absl::StatusOr<riegeli::Chunk> {
//         ... returns a riegeli::Chunk on success.
//       });
//
//   // Step 3: book a slot for writing the encoded chunk.
//   RET_CHECK(writer->CommitFutureChunk(
//       encoding_task.get_future())) << writer->status();
//
//   // Step 4: compute the encoding task in a thread pool.
//   pool->Schedule([=, encoding_task = std::move(encoding_task)]() mutable {
//     encoding_task();  // std::promise fulfilled.
//     // riegeli::SharedPtr prevents the writer from going out of scope, so
//     // it is safe to invoke the method here.
//     writer->SubmitFutureChunks(false);
//   });
//
//   // Repeat steps 2 to 4.
//
//   // Finally, close the writer.
//   RET_CHECK(writer->Close()) << writer->status();
//
//
// It is necessary to call `Close()` at the end of a successful writing
// session. There is no need to call `Close()` on early returns, assuming that
// the contents of the destination do not matter after all, e.g. because a
// failure is being reported instead; the destructor releases resources in any
// case.
//
// `SequencedChunkWriter` inherits from riegeli::Object, which provides useful
// abstractions for state management of IO-like operations. Instead of the
// common absl::Status/StatusOr for each method, riegeli::Object's error
// handling mechanism uses bool plus separate `status()`, `ok()`, and
// `is_open()` accessors for users to handle different types of failure states.
//
//
// `SequencedChunkWriter` uses a templated backend abstraction. To serialize
// the output to a string, the user simply writes:
//
//   std::string dest;
//   SequencedChunkWriter<riegeli::StringWriter<>> writes_to_string(
//       riegeli::Maker(&dest));
//
// Similarly, the user can write the output to a cord or to a file.
//
//   absl::Cord cord;
//   SequencedChunkWriter<riegeli::CordWriter<>> writes_to_cord(
//       riegeli::Maker(&cord));
//
//   SequencedChunkWriter<riegeli::FdWriter<>> writes_to_file(
//       riegeli::Maker(filename_or_file));
//
// Users may also use riegeli::SharedPtr or std::make_unique to construct the
// instance, as shown in the previous example.
template <typename Dest>
class SequencedChunkWriter : public SequencedChunkWriterBase {
 public:
  DECLARE_IMMOBILE_CLASS(SequencedChunkWriter);

  // Will write to the `Writer` provided by `dest`.
  explicit SequencedChunkWriter(riegeli::Initializer<Dest> dest)
      : dest_(std::move(dest)) {
    Initialize();
  }

 protected:
  riegeli::ChunkWriter* get_writer() final { return &dest_; }

 private:
  riegeli::DefaultChunkWriter<Dest> dest_;
};

template <typename Dest>
explicit SequencedChunkWriter(Dest&& dest)
    -> SequencedChunkWriter<riegeli::TargetT<Dest>>;

}  // namespace array_record

#endif  // ARRAY_RECORD_CPP_SEQUENCED_CHUNK_WRITER_H_
--------------------------------------------------------------------------------
/cpp/sequenced_chunk_writer_test.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/sequenced_chunk_writer.h"

#include <cstdint>
#include <future>  // NOLINT(build/c++11)
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "cpp/common.h"
#include "cpp/thread_pool.h"
#include "riegeli/base/maker.h"
#include "riegeli/bytes/chain_writer.h"
#include "riegeli/bytes/cord_writer.h"
#include "riegeli/bytes/string_reader.h"
#include "riegeli/bytes/string_writer.h"
#include "riegeli/chunk_encoding/chunk.h"
#include "riegeli/chunk_encoding/compressor_options.h"
#include "riegeli/chunk_encoding/constants.h"
#include "riegeli/chunk_encoding/simple_encoder.h"
#include "riegeli/records/record_reader.h"

namespace array_record {
namespace {

TEST(SequencedChunkWriterTest, RvalCtorTest) {
  // Constructs SequencedChunkWriter by taking ownership of another
  // riegeli writer.
  {
    std::string dest;
    auto str_writer = riegeli::StringWriter(&dest);
    auto to_string = SequencedChunkWriter(std::move(str_writer));
  }
  {
    absl::Cord cord;
    auto cord_writer = riegeli::CordWriter(&cord);
    auto to_cord = SequencedChunkWriter(std::move(cord_writer));
  }
  {
    std::string dest;
    auto str_writer = riegeli::StringWriter(&dest);
    auto to_string =
        std::make_unique<SequencedChunkWriter<riegeli::StringWriter<>>>(
            std::move(str_writer));
  }
  {
    absl::Cord cord;
    auto cord_writer = riegeli::CordWriter(&cord);
    auto to_cord =
        std::make_unique<SequencedChunkWriter<riegeli::CordWriter<>>>(
            std::move(cord_writer));
  }
}

TEST(SequencedChunkWriterTest, DestArgsCtorTest) {
  // Constructs SequencedChunkWriter by forwarding the constructor arguments
  // to the templated riegeli writer.
  {
    std::string dest;
    auto to_string =
        SequencedChunkWriter<riegeli::StringWriter<>>(riegeli::Maker(&dest));
  }
  {
    absl::Cord cord;
    auto to_cord =
        SequencedChunkWriter<riegeli::CordWriter<>>(riegeli::Maker(&cord));
  }

  {
    std::string dest;
    auto to_string =
        std::make_unique<SequencedChunkWriter<riegeli::StringWriter<>>>(
            riegeli::Maker(&dest));
  }
  {
    absl::Cord cord;
    auto to_cord =
        std::make_unique<SequencedChunkWriter<riegeli::CordWriter<>>>(
            riegeli::Maker(&cord));
  }
}

class TestCommitChunkCallback
    : public SequencedChunkWriterBase::SubmitChunkCallback {
 public:
  void operator()(uint64_t chunk_seq, uint64_t chunk_offset,
                  uint64_t decoded_data_size, uint64_t num_records) override {
    chunk_offsets_.push_back(chunk_offset);
  }
  absl::Span<const uint64_t> get_chunk_offsets() const {
    return chunk_offsets_;
  }

 private:
  std::vector<uint64_t> chunk_offsets_;
};

TEST(SequencedChunkWriterTest, SanityTestCodeSnippet) {
  std::string encoded;
  auto callback = TestCommitChunkCallback();

  auto writer =
      std::make_shared<SequencedChunkWriter<riegeli::StringWriter<>>>(
          riegeli::Maker(&encoded));
  writer->set_submit_chunk_callback(&callback);
  ASSERT_TRUE(writer->ok()) << writer->status();

  for (auto i : Seq(3)) {
    std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task([i] {
      riegeli::Chunk chunk;
      riegeli::SimpleEncoder encoder(
          riegeli::CompressorOptions().set_uncompressed(),
          riegeli::SimpleEncoder::TuningOptions().set_size_hint(1));
      std::string text_to_encode = std::to_string(i);
      EXPECT_TRUE(encoder.AddRecord(absl::string_view(text_to_encode)));
      riegeli::ChunkType chunk_type;
      uint64_t decoded_data_size;
      uint64_t num_records;
      riegeli::ChainWriter chain_writer(&chunk.data);
      EXPECT_TRUE(encoder.EncodeAndClose(chain_writer, chunk_type, num_records,
                                         decoded_data_size));
      EXPECT_TRUE(chain_writer.Close());
      chunk.header = riegeli::ChunkHeader(chunk.data, chunk_type, num_records,
                                          decoded_data_size);
      return chunk;
    });
    ASSERT_TRUE(writer->CommitFutureChunk(encoding_task.get_future()));
    encoding_task();
    writer->SubmitFutureChunks(false);
  }
  // Calling SubmitFutureChunks(true) blocks the current thread until all
  // encoding tasks complete.
  EXPECT_TRUE(writer->SubmitFutureChunks(true));
  // Padding should not cause any failure.
  EXPECT_TRUE(writer->Close());

  // A file produced by SequencedChunkWriter should be a valid riegeli file.
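  // (Initialize() wrote the riegeli file-signature chunk up front, which is
  // what makes the output readable by a stock riegeli::RecordReader.)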
  auto reader = riegeli::RecordReader(
      riegeli::Maker<riegeli::StringReader>(encoded));
  ASSERT_TRUE(reader.CheckFileFormat());
  // Read sequentially.
  absl::Cord result;
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "0");
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "1");
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "2");
  EXPECT_FALSE(reader.ReadRecord(result));

  // We can use the chunk_offsets information to randomly access records.
  auto offsets = callback.get_chunk_offsets();
  EXPECT_TRUE(reader.Seek(offsets[1]));
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "1");
  EXPECT_TRUE(reader.Seek(offsets[0]));
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "0");
  EXPECT_TRUE(reader.Seek(offsets[2]));
  EXPECT_TRUE(reader.ReadRecord(result));
  EXPECT_EQ(result, "2");

  EXPECT_TRUE(reader.Close());
}

TEST(SequencedChunkWriterTest, SanityTestBadChunk) {
  std::string encoded;
  auto callback = TestCommitChunkCallback();

  auto writer =
      std::make_shared<SequencedChunkWriter<riegeli::StringWriter<>>>(
          riegeli::Maker(&encoded));
  writer->set_submit_chunk_callback(&callback);
  ASSERT_TRUE(writer->ok()) << writer->status();
  std::packaged_task<absl::StatusOr<riegeli::Chunk>()> encoding_task(
      [] { return absl::InternalError("On purpose"); });
  EXPECT_TRUE(writer->CommitFutureChunk(encoding_task.get_future()));
  EXPECT_TRUE(writer->SubmitFutureChunks(false));
  encoding_task();
  // We should see the error being propagated even when we run the
  // non-blocking version.
  EXPECT_FALSE(writer->SubmitFutureChunks(false));
  EXPECT_EQ(writer->status().code(), absl::StatusCode::kInternal);

  EXPECT_FALSE(writer->Close());
  EXPECT_TRUE(callback.get_chunk_offsets().empty());
}

}  // namespace
}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/test_utils.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/test_utils.h"

#include <cstring>
#include <string>

#include "cpp/common.h"

namespace array_record {

std::string MTRandomBytes(std::mt19937& bitgen, size_t length) {
  std::string result(length, '\0');

  size_t gen_bytes = sizeof(uint32_t);
  size_t rem = length % gen_bytes;
  std::mt19937::result_type val = bitgen();
  char* ptr = result.data();
  std::memcpy(ptr, &val, rem);
  ptr += rem;

  for (auto _ : Seq(length / gen_bytes)) {
    uint32_t val = bitgen();
    std::memcpy(ptr, &val, gen_bytes);
    ptr += gen_bytes;
  }
  return result;
}

}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/test_utils.h:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ARRAY_RECORD_CPP_TEST_UTILS_H_
#define ARRAY_RECORD_CPP_TEST_UTILS_H_

#include <random>

namespace array_record {

// Generates a sequence of random bytes deterministically.
std::string MTRandomBytes(std::mt19937& bitgen, size_t length);

}  // namespace array_record

#endif  // ARRAY_RECORD_CPP_TEST_UTILS_H_
--------------------------------------------------------------------------------
/cpp/test_utils_test.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/test_utils.h"

#include <random>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "cpp/common.h"

namespace array_record {
namespace {

TEST(MTRandomBytesTest, ZeroLen) {
  std::mt19937 bitgen;
  auto result = MTRandomBytes(bitgen, 0);
  ASSERT_EQ(result.size(), 0);
}

TEST(MTRandomBytesTest, OneByte) {
  std::mt19937 bitgen, bitgen2;
  auto result = MTRandomBytes(bitgen, 1);
  ASSERT_EQ(result.size(), 1);
  ASSERT_NE(result[0], '\0');

  auto val = bitgen2();
  char char_val = *reinterpret_cast<char*>(&val);
  ASSERT_EQ(result[0], char_val);
}

TEST(MTRandomBytesTest, LargeVals) {
  constexpr size_t len = 123;
  std::mt19937 bitgen, bitgen2;

  auto result1 = MTRandomBytes(bitgen, len);
  auto result2 = MTRandomBytes(bitgen2, len);
  ASSERT_EQ(result1.size(), len);
  ASSERT_EQ(result2.size(), len);
  ASSERT_EQ(result1, result2);
}

}  // namespace

}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/thread_pool.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/thread_pool.h"

#include "absl/flags/flag.h"

ABSL_FLAG(uint32_t, array_record_global_pool_size, 64,
          "Number of threads for ArrayRecordGlobalPool");

namespace array_record {

ARThreadPool* ArrayRecordGlobalPool() {
  static ARThreadPool* pool_ = []() -> ARThreadPool* {
    ARThreadPool* pool = new Eigen::ThreadPool(
        absl::GetFlag(FLAGS_array_record_global_pool_size));
    return pool;
  }();
  return pool_;
}

}  // namespace array_record
--------------------------------------------------------------------------------
/cpp/thread_pool.h:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ARRAY_RECORD_CPP_THREAD_POOL_H_
#define ARRAY_RECORD_CPP_THREAD_POOL_H_

#define EIGEN_USE_CUSTOM_THREAD_POOL
#include "unsupported/Eigen/CXX11/ThreadPool"

namespace array_record {

using ARThreadPool = Eigen::ThreadPoolInterface;

ARThreadPool* ArrayRecordGlobalPool();

}  // namespace array_record

#endif  // ARRAY_RECORD_CPP_THREAD_POOL_H_
--------------------------------------------------------------------------------
/cpp/tri_state_ptr.h:
--------------------------------------------------------------------------------
/* Copyright 2024 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ARRAY_RECORD_CPP_TRI_STATE_PTR_H_
#define ARRAY_RECORD_CPP_TRI_STATE_PTR_H_

#include <stdint.h>

#include <atomic>
#include <cstddef>
#include <memory>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "cpp/common.h"

namespace array_record {

/** TriStatePtr is a wrapper around a pointer that allows for unique and
 * shared references.
 *
 * There are three states:
 *
 * - NoRef: The object has no shared or unique references.
 * - Sharing: The object is referenced by one or more SharedRef wrappers.
 * - Unique: The object is referenced by a unique pointer wrapper.
 *
 * The state transitions from NoRef to Sharing when MakeShared is called.
 * An internal reference count is incremented when a SharedRef is created.
 *
 *           SharedRef ref = MakeShared();
 *   NoRef ----------------------------> Sharing
 *         <----------------------------
 *           All SharedRefs deallocated
 *
 * The state can also transition to Unique when WaitAndMakeUnique is called.
 * We can only hold one unique reference at a time.
 *
 *           UniqueRef ref = WaitAndMakeUnique();
 *   NoRef ----------------------------> Unique
 *         <----------------------------
 *           The UniqueRef is deallocated
 *
 * Other than the transitions above, the state-transition methods block until
 * the requested state is possible. On deallocation, the destructor blocks
 * until the state is NoRef.
 *
 * Example usage:
 *
 *   TriStatePtr<FooBase> foo_ptr(riegeli::Maker<Foo>(...));
 *   // Create a shared reference to work on other threads.
 *   pool->Schedule([refobj = foo_ptr.MakeShared()] {
 *     refobj->FooMethod();
 *   });
 *
 *   // Blocks until refobj is out of scope.
 *   auto unique_ref = foo_ptr.WaitAndMakeUnique();
 *   unique_ref->CleanupFoo();
 *
 */
template <typename BaseT>
class TriStatePtr {
 public:
  DECLARE_IMMOBILE_CLASS(TriStatePtr);
  TriStatePtr() = default;

  ~TriStatePtr() {
    absl::MutexLock l(&mu_);
    mu_.Await(absl::Condition(
        +[](State* sharing_state) { return *sharing_state == State::kNoRef; },
        &state_));
  }

  explicit TriStatePtr(std::unique_ptr<BaseT> ptr) : ptr_(std::move(ptr)) {}

  class SharedRef {
   public:
    SharedRef(TriStatePtr* parent) : parent_(parent) {}

    SharedRef(const SharedRef& other) : parent_(other.parent_) {
      parent_->ref_count_++;
    }
    SharedRef& operator=(const SharedRef& other) {
      parent_ = other.parent_;
      parent_->ref_count_++;
      return *this;
    }

    ~SharedRef() {
      int32_t ref_count =
          parent_->ref_count_.fetch_sub(1, std::memory_order_acq_rel) - 1;
      if (ref_count == 0) {
        absl::MutexLock l(&parent_->mu_);
        parent_->state_ = State::kNoRef;
      }
    }

    const BaseT& operator*() const { return *parent_->ptr_.get(); }
    const BaseT* operator->() const { return parent_->ptr_.get(); }
    BaseT& operator*() { return *parent_->ptr_.get(); }
    BaseT* operator->() { return parent_->ptr_.get(); }

   private:
    TriStatePtr* parent_ = nullptr;
  };

  class UniqueRef {
   public:
    DECLARE_MOVE_ONLY_CLASS(UniqueRef);
    UniqueRef(TriStatePtr* parent) : parent_(parent) {}

    ~UniqueRef() {
      absl::MutexLock l(&parent_->mu_);
      parent_->state_ = State::kNoRef;
    }

    const BaseT& operator*() const { return *parent_->ptr_.get(); }
    const BaseT* operator->() const { return parent_->ptr_.get(); }
    BaseT& operator*() { return *parent_->ptr_.get(); }
    BaseT* operator->() { return parent_->ptr_.get(); }

   private:
    TriStatePtr* parent_;
  };

  SharedRef MakeShared() {
    absl::MutexLock l(&mu_);
    mu_.Await(absl::Condition(
        +[](State* sharing_state) { return *sharing_state != State::kUnique; },
        &state_));
    state_ = State::kSharing;
    ref_count_++;
    return SharedRef(this);
  }

  UniqueRef WaitAndMakeUnique() {
    absl::MutexLock l(&mu_);
    mu_.Await(absl::Condition(
        +[](State* sharing_state) { return *sharing_state == State::kNoRef; },
        &state_));
    state_ = State::kUnique;
    return UniqueRef(this);
  }

  enum class State {
    kNoRef = 0,
    kSharing = 1,
    kUnique = 2,
  };

  State state() const {
    absl::MutexLock l(&mu_);
    return state_;
  }

 private:
  mutable absl::Mutex mu_;
  std::atomic_int32_t ref_count_ = 0;
  State state_ ABSL_GUARDED_BY(mu_) = State::kNoRef;
  std::unique_ptr<BaseT> ptr_;
};

}  // namespace array_record

#endif  // ARRAY_RECORD_CPP_TRI_STATE_PTR_H_
--------------------------------------------------------------------------------
/cpp/tri_state_ptr_test.cc:
--------------------------------------------------------------------------------
/* Copyright 2022 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "cpp/tri_state_ptr.h"

#include <utility>

#include "gtest/gtest.h"
#include "absl/synchronization/notification.h"
#include "cpp/common.h"
#include "cpp/thread_pool.h"
#include "riegeli/base/maker.h"

namespace array_record {
namespace {

class FooBase {
 public:
  virtual ~FooBase() = default;
  virtual int value() const = 0;
  virtual void add_value(int v) = 0;
  virtual void mul_value(int v) = 0;
};

class Foo : public FooBase {
 public:
  explicit Foo(int v) : value_(v) {}
  DECLARE_MOVE_ONLY_CLASS(Foo);

  int value() const override { return value_; }
  void add_value(int v) override { value_ += v; }
  void mul_value(int v) override { value_ *= v; }

 private:
  int value_;
};

class TriStatePtrTest : public testing::Test {
 public:
  TriStatePtrTest() : pool_(ArrayRecordGlobalPool()) {}

 protected:
  ARThreadPool* pool_;
};

TEST_F(TriStatePtrTest, SanityTest) {
  TriStatePtr<FooBase> foo_main(std::move(riegeli::Maker<Foo>(1)));
  EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kNoRef);
  absl::Notification notification;
  {
    pool_->Schedule(
        [foo_shared = foo_main.MakeShared(), &notification]() mutable {
          notification.WaitForNotification();
          EXPECT_EQ(foo_shared->value(), 1);
          const auto second_foo_shared = foo_shared;
          foo_shared->add_value(1);
          EXPECT_EQ(second_foo_shared->value(), 2);
        });
  }
  EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kSharing);
  notification.Notify();
  auto foo_unique = foo_main.WaitAndMakeUnique();
  foo_unique->mul_value(3);
  EXPECT_EQ(foo_unique->value(), 6);
  EXPECT_EQ(foo_main.state(), TriStatePtr<FooBase>::State::kUnique);
}

}  // namespace
}  // namespace array_record
--------------------------------------------------------------------------------
/oss/README.md:
--------------------------------------------------------------------------------
# Steps to build a new array_record pip package

1. Update the version number in setup.py

2. In the root folder, run

   ```
   ./oss/build_whl.sh
   ```

   to use the current `python3` version. Otherwise, optionally set

   ```
   PYTHON_VERSION=3.9 ./oss/build_whl.sh
   ```

3. Wheels are in `all_dist/`.
--------------------------------------------------------------------------------
/oss/build.Dockerfile:
--------------------------------------------------------------------------------
# Constructs the environment within which we will build the pip wheels.

ARG AUDITWHEEL_PLATFORM

FROM quay.io/pypa/${AUDITWHEEL_PLATFORM}

ARG PYTHON_VERSION
ARG PYTHON_BIN
ARG BAZEL_VERSION

ENV DEBIAN_FRONTEND=noninteractive

RUN ulimit -n 1024 && yum install -y rsync
ENV PATH="${PYTHON_BIN}:${PATH}"

# Download the correct bazel version and make sure it's on the path.
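# (BAZEL_VERSION is supplied as a --build-arg; see oss/runner_common.sh.)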
18 | RUN BAZEL_ARCH_SUFFIX="$(uname -m | sed s/aarch64/arm64/)" \ 19 | && curl -sSL --fail -o /usr/local/bin/bazel "https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-linux-$BAZEL_ARCH_SUFFIX" \ 20 | && chmod a+x /usr/local/bin/bazel 21 | 22 | # Install dependencies needed for array_record. 23 | RUN --mount=type=cache,target=/root/.cache \ 24 | ${PYTHON_BIN}/python -m pip install -U \ 25 | absl-py \ 26 | auditwheel \ 27 | etils[epath] \ 28 | patchelf \ 29 | setuptools \ 30 | twine \ 31 | wheel; 32 | 33 | WORKDIR "/tmp/array_record" -------------------------------------------------------------------------------- /oss/build_whl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Build wheel for the python version specified by $PYTHON_VERSION. 3 | # Optionally, can set the environment variable $PYTHON_BIN to refer to a 4 | # specific python interpreter. 5 | 6 | set -e -x 7 | 8 | if [ -z ${PYTHON_BIN} ]; then 9 | if [ -z ${PYTHON_VERSION} ]; then 10 | PYTHON_BIN=$(which python3) 11 | else 12 | PYTHON_BIN=$(which python${PYTHON_VERSION}) 13 | fi 14 | fi 15 | 16 | PYTHON_MAJOR_VERSION=$(${PYTHON_BIN} -c 'import sys; print(sys.version_info.major)') 17 | PYTHON_MINOR_VERSION=$(${PYTHON_BIN} -c 'import sys; print(sys.version_info.minor)') 18 | PYTHON_VERSION="${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}" 19 | export PYTHON_VERSION="${PYTHON_VERSION}" 20 | 21 | function write_to_bazelrc() { 22 | echo "$1" >> .bazelrc 23 | } 24 | 25 | function main() { 26 | # Remove .bazelrc if it already exists 27 | [ -e .bazelrc ] && rm .bazelrc 28 | 29 | write_to_bazelrc "build -c opt" 30 | write_to_bazelrc "build --cxxopt=-std=c++17" 31 | write_to_bazelrc "build --host_cxxopt=-std=c++17" 32 | write_to_bazelrc "build --experimental_repo_remote_exec" 33 | write_to_bazelrc "build --python_path=\"${PYTHON_BIN}\"" 34 | 35 | if [ -n "${CROSSTOOL_TOP}" ]; then 36 | write_to_bazelrc "build --crosstool_top=${CROSSTOOL_TOP}" 37 | write_to_bazelrc "test --crosstool_top=${CROSSTOOL_TOP}" 38 | fi 39 | 40 | export USE_BAZEL_VERSION="${BAZEL_VERSION}" 41 | bazel clean 42 | bazel build ... 43 | bazel test --verbose_failures --test_output=errors ... 44 | 45 | DEST="/tmp/array_record/all_dist" 46 | # Create the directory, then do dirname on a non-existent file inside it to 47 | # give us an absolute paths with tilde characters resolved to the destination 48 | # directory. 49 | mkdir -p "${DEST}" 50 | echo "=== destination directory: ${DEST}" 51 | 52 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) 53 | 54 | echo $(date) : "=== Using tmpdir: ${TMPDIR}" 55 | mkdir "${TMPDIR}/array_record" 56 | 57 | echo $(date) : "=== Copy array_record files" 58 | 59 | cp setup.py "${TMPDIR}" 60 | cp LICENSE "${TMPDIR}" 61 | rsync -avm -L --exclude="bazel-*/" . 
"${TMPDIR}/array_record" 62 | rsync -avm -L --include="*.so" --include="*_pb2.py" \ 63 | --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \ 64 | bazel-bin/cpp "${TMPDIR}/array_record" 65 | rsync -avm -L --include="*.so" --include="*_pb2.py" \ 66 | --exclude="*.runfiles" --exclude="*_obj" --include="*/" --exclude="*" \ 67 | bazel-bin/python "${TMPDIR}/array_record" 68 | 69 | pushd ${TMPDIR} 70 | echo $(date) : "=== Building wheel" 71 | ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION} 72 | 73 | if [ -n "${AUDITWHEEL_PLATFORM}" ]; then 74 | echo $(date) : "=== Auditing wheel" 75 | auditwheel repair --plat ${AUDITWHEEL_PLATFORM} -w dist dist/*.whl 76 | fi 77 | 78 | echo $(date) : "=== Listing wheel" 79 | ls -lrt dist/*.whl 80 | cp dist/*.whl "${DEST}" 81 | popd 82 | 83 | echo $(date) : "=== Output wheel file is in: ${DEST}" 84 | } 85 | 86 | main 87 | -------------------------------------------------------------------------------- /oss/runner_common.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Builds ArrayRecord from source code located in SOURCE_DIR producing wheels 4 | # under $SOURCE_DIR/all_dist. 5 | function build_and_test_array_record_linux() { 6 | SOURCE_DIR=$1 7 | 8 | # Automatically decide which platform to build for by checking on which 9 | # platform this runs. 10 | AUDITWHEEL_PLATFORM="manylinux2014_$(uname -m)" 11 | 12 | # Using a previous version of Blaze to avoid: 13 | # https://github.com/bazelbuild/bazel/issues/8622 14 | export BAZEL_VERSION="8.0.0" 15 | 16 | # Build wheels for multiple Python minor versions. 17 | PYTHON_MAJOR_VERSION=3 18 | for PYTHON_MINOR_VERSION in 10 11 12 19 | do 20 | PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION} 21 | PYTHON_BIN=/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin 22 | 23 | # Cleanup older images. 24 | docker rmi -f array_record:${PYTHON_VERSION} 25 | docker rm -f array_record 26 | 27 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \ 28 | --build-arg AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ 29 | --build-arg PYTHON_VERSION=${PYTHON_VERSION} \ 30 | --build-arg PYTHON_BIN=${PYTHON_BIN} \ 31 | --build-arg BAZEL_VERSION=${BAZEL_VERSION} \ 32 | -t array_record:${PYTHON_VERSION} - < ${SOURCE_DIR}/oss/build.Dockerfile 33 | 34 | docker run --rm -a stdin -a stdout -a stderr \ 35 | --env PYTHON_BIN="${PYTHON_BIN}/python" \ 36 | --env BAZEL_VERSION=${BAZEL_VERSION} \ 37 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ 38 | -v $SOURCE_DIR:/tmp/array_record \ 39 | --name array_record array_record:${PYTHON_VERSION} \ 40 | bash oss/build_whl.sh 41 | done 42 | 43 | ls ${SOURCE_DIR}/all_dist/*.whl 44 | } 45 | 46 | function install_and_init_pyenv { 47 | pyenv_root=${1:-$HOME/.pyenv} 48 | export PYENV_ROOT=$pyenv_root 49 | if [[ ! -d $PYENV_ROOT ]]; then 50 | echo "Installing pyenv.." 51 | git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT" 52 | export PATH="/home/kbuilder/.local/bin:$PYENV_ROOT/bin:$PATH" 53 | eval "$(pyenv init --path)" 54 | fi 55 | 56 | echo "Python setup..." 57 | pyenv install -s "$PYENV_PYTHON_VERSION" 58 | pyenv global "$PYENV_PYTHON_VERSION" 59 | PYTHON=$(pyenv which python) 60 | } 61 | 62 | function setup_env_vars_py310 { 63 | # This controls the python binary to use. 64 | PYTHON=python3.10 65 | PYTHON_STR=python3.10 66 | PYTHON_MAJOR_VERSION=3 67 | PYTHON_MINOR_VERSION=10 68 | # This is for pyenv install. 
69 | PYENV_PYTHON_VERSION=3.10.13 70 | } 71 | 72 | function update_bazel_macos { 73 | BAZEL_VERSION=$1 74 | ARCH="$(uname -m)" 75 | curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh -O 76 | ls 77 | chmod +x bazel-*.sh 78 | ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh --user 79 | rm -f ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh 80 | # Add new bazel installation to path 81 | PATH="/Users/kbuilder/bin:$PATH" 82 | } 83 | 84 | function build_and_test_array_record_macos() { 85 | SOURCE_DIR=$1 86 | # Set up Bazel. 87 | # Using a previous version of Bazel to avoid: 88 | # https://github.com/bazelbuild/bazel/issues/8622 89 | export BAZEL_VERSION="8.0.0" 90 | update_bazel_macos ${BAZEL_VERSION} 91 | bazel --version 92 | 93 | # Set up Pyenv. 94 | setup_env_vars_py310 95 | install_and_init_pyenv 96 | 97 | # Build and test ArrayRecord. 98 | cd ${SOURCE_DIR} 99 | bash ${SOURCE_DIR}/oss/build_whl.sh 100 | 101 | ls ${SOURCE_DIR}/all_dist/*.whl 102 | } 103 | -------------------------------------------------------------------------------- /python/BUILD: -------------------------------------------------------------------------------- 1 | # Python binding for ArrayRecord 2 | 3 | load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") 4 | load("@pypi//:requirements.bzl", "requirement") 5 | 6 | package(default_visibility = ["//visibility:public"]) 7 | 8 | licenses(["notice"]) 9 | 10 | pybind_extension( 11 | name = "array_record_module", 12 | srcs = ["array_record_module.cc"], 13 | deps = [ 14 | "@abseil-cpp//absl/status", 15 | "@abseil-cpp//absl/strings", 16 | "@abseil-cpp//absl/strings:str_format", 17 | "//cpp:array_record_reader", 18 | "//cpp:array_record_writer", 19 | "//cpp:thread_pool", 20 | "@riegeli//riegeli/base:initializer", 21 | "@riegeli//riegeli/bytes:fd_reader", 22 | "@riegeli//riegeli/bytes:fd_writer", 23 | ], 24 | ) 25 | 26 | py_test( 27 | name = "array_record_module_test", 28 | srcs = ["array_record_module_test.py"], 29 | data = [":array_record_module.so"], 30 | deps = [ 31 | "@abseil-py//absl/testing:absltest", 32 | ], 33 | ) 34 | 35 | py_library( 36 | name = "array_record_data_source", 37 | srcs = ["array_record_data_source.py"], 38 | data = [":array_record_module.so"], 39 | deps = [ 40 | requirement("etils"), 41 | ], 42 | ) 43 | 44 | py_test( 45 | name = "array_record_data_source_test", 46 | srcs = ["array_record_data_source_test.py"], 47 | args = ["--test_srcdir=python/testdata"], 48 | data = [ 49 | ":array_record_module.so", 50 | "//python/testdata:digits.array_record-00000-of-00002", 51 | "//python/testdata:digits.array_record-00001-of-00002", 52 | ], 53 | deps = [ 54 | ":array_record_data_source", 55 | "@abseil-py//absl/testing:absltest", 56 | "@abseil-py//absl/testing:flagsaver", 57 | "@abseil-py//absl/testing:parameterized", 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/__init__.py -------------------------------------------------------------------------------- /python/array_record_data_source_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # https://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | 14 | """Tests for ArrayRecord data sources.""" 15 | 16 | from concurrent import futures 17 | import dataclasses 18 | import os 19 | import pathlib 20 | from unittest import mock 21 | 22 | from absl import flags 23 | from absl.testing import absltest 24 | from absl.testing import flagsaver 25 | from absl.testing import parameterized 26 | 27 | from python import array_record_data_source 28 | from python import array_record_module 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | @dataclasses.dataclass 35 | class DummyFileInstruction: 36 | filename: str 37 | skip: int 38 | take: int 39 | examples_in_shard: int 40 | 41 | 42 | class ArrayRecordDataSourcesTest(absltest.TestCase): 43 | 44 | def setUp(self): 45 | super().setUp() 46 | self.testdata_dir = pathlib.Path(FLAGS.test_srcdir) 47 | 48 | def test_check_default_group_size(self): 49 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record") 50 | writer = array_record_module.ArrayRecordWriter(filename) 51 | writer.write(b"foobar") 52 | writer.close() 53 | reader = array_record_module.ArrayRecordReader(filename) 54 | with self.assertLogs(level="ERROR") as log_output: 55 | array_record_data_source._check_group_size(filename, reader) 56 | self.assertRegex( 57 | log_output.output[0], 58 | ( 59 | r"File .* was created with group size 65536. Grain requires group" 60 | r" size 1 for good performance" 61 | ), 62 | ) 63 | 64 | def test_check_valid_group_size(self): 65 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record") 66 | writer = array_record_module.ArrayRecordWriter(filename, "group_size:1") 67 | writer.write(b"foobar") 68 | writer.close() 69 | reader = array_record_module.ArrayRecordReader(filename) 70 | array_record_data_source._check_group_size(filename, reader) 71 | def test_check_invalid_group_size(self): 72 | filename = os.path.join(FLAGS.test_tmpdir, "test.array_record") 73 | writer = array_record_module.ArrayRecordWriter(filename, "group_size:11") 74 | writer.write(b"foobar") 75 | writer.close() 76 | reader = array_record_module.ArrayRecordReader(filename) 77 | with self.assertLogs(level="ERROR") as log_output: 78 | array_record_data_source._check_group_size(filename, reader) 79 | self.assertRegex( 80 | log_output.output[0], 81 | ( 82 | r"File .* was created with group size 11. 
Grain requires group size" 83 | r" 1 for good performance" 84 | ), 85 | ) 86 | 87 | def test_array_record_data_source_len(self): 88 | ar = array_record_data_source.ArrayRecordDataSource([ 89 | self.testdata_dir / "digits.array_record-00000-of-00002", 90 | self.testdata_dir / "digits.array_record-00001-of-00002", 91 | ]) 92 | self.assertLen(ar, 10) 93 | 94 | def test_array_record_data_source_iter(self): 95 | ar = array_record_data_source.ArrayRecordDataSource([ 96 | self.testdata_dir / "digits.array_record-00000-of-00002", 97 | self.testdata_dir / "digits.array_record-00001-of-00002", 98 | ]) 99 | digits = [b"0", b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8", b"9"] 100 | for actual, expected in zip(ar, digits): 101 | self.assertEqual(actual, expected) 102 | 103 | def test_array_record_data_source_single_path(self): 104 | indices_to_read = [0, 1, 2, 3, 4] 105 | expected_data = [b"0", b"1", b"2", b"3", b"4"] 106 | # Use a single path instead of a list of paths/file_instructions. 107 | with array_record_data_source.ArrayRecordDataSource( 108 | self.testdata_dir / "digits.array_record-00000-of-00002" 109 | ) as ar: 110 | actual_data = [ar[x] for x in indices_to_read] 111 | self.assertEqual(expected_data, actual_data) 112 | self.assertTrue(all(reader is None for reader in ar._readers)) 113 | 114 | def test_array_record_data_source_string_read_instructions(self): 115 | indices_to_read = [0, 1, 2, 3, 4] 116 | expected_data = [b"0", b"1", b"2", b"7", b"8"] 117 | # Use string paths with read instructions (slices) instead of plain paths. 118 | ar = array_record_data_source.ArrayRecordDataSource([ 119 | self.testdata_dir / "digits.array_record-00000-of-00002[0:3]", 120 | self.testdata_dir / "digits.array_record-00001-of-00002[2:4]", 121 | ]) 122 | self.assertLen(ar, 5) 123 | actual_data = [ar[x] for x in indices_to_read] 124 | self.assertEqual(expected_data, actual_data) 125 | 126 | def test_array_record_data_source_reverse_order(self): 127 | indices_to_read = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] 128 | expected_data = [b"9", b"8", b"7", b"6", b"5", b"4", b"3", b"2", b"1", b"0"] 129 | with array_record_data_source.ArrayRecordDataSource([ 130 | self.testdata_dir / "digits.array_record-00000-of-00002", 131 | self.testdata_dir / "digits.array_record-00001-of-00002", 132 | ]) as ar: 133 | actual_data = [ar[x] for x in indices_to_read] 134 | self.assertEqual(expected_data, actual_data) 135 | self.assertTrue(all(reader is None for reader in ar._readers)) 136 | 137 | def test_array_record_data_source_random_order(self): 138 | # some random permutation 139 | indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6] 140 | expected_data = [b"3", b"0", b"5", b"9", b"2", b"1", b"4", b"7", b"8", b"6"] 141 | with array_record_data_source.ArrayRecordDataSource([ 142 | self.testdata_dir / "digits.array_record-00000-of-00002", 143 | self.testdata_dir / "digits.array_record-00001-of-00002", 144 | ]) as ar: 145 | actual_data = [ar[x] for x in indices_to_read] 146 | self.assertEqual(expected_data, actual_data) 147 | self.assertTrue(all(reader is None for reader in ar._readers)) 148 | 149 | def test_array_record_data_source_random_order_batched(self): 150 | # some random permutation 151 | indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6] 152 | expected_data = [b"3", b"0", b"5", b"9", b"2", b"1", b"4", b"7", b"8", b"6"] 153 | with array_record_data_source.ArrayRecordDataSource([ 154 | self.testdata_dir / "digits.array_record-00000-of-00002", 155 | self.testdata_dir / "digits.array_record-00001-of-00002", 156 | ]) as ar: 157 | actual_data = 
ar.__getitems__(indices_to_read) 158 | self.assertEqual(expected_data, actual_data) 159 | self.assertTrue(all(reader is None for reader in ar._readers)) 160 | 161 | def test_array_record_data_source_file_instructions(self): 162 | file_instruction_one = DummyFileInstruction( 163 | filename=os.fspath( 164 | self.testdata_dir / "digits.array_record-00000-of-00002" 165 | ), 166 | skip=2, 167 | take=1, 168 | examples_in_shard=3, 169 | ) 170 | 171 | file_instruction_two = DummyFileInstruction( 172 | filename=os.fspath( 173 | self.testdata_dir / "digits.array_record-00001-of-00002" 174 | ), 175 | skip=2, 176 | take=2, 177 | examples_in_shard=99, 178 | ) 179 | 180 | indices_to_read = [0, 1, 2] 181 | expected_data = [b"2", b"7", b"8"] 182 | 183 | with array_record_data_source.ArrayRecordDataSource( 184 | [file_instruction_one, file_instruction_two] 185 | ) as ar: 186 | self.assertLen(ar, 3) 187 | actual_data = [ar[x] for x in indices_to_read] 188 | 189 | self.assertEqual(expected_data, actual_data) 190 | self.assertTrue(all(reader is None for reader in ar._readers)) 191 | 192 | def test_array_record_source_reader_idx_and_position(self): 193 | file_instructions = [ 194 | # 2 records 195 | DummyFileInstruction( 196 | filename="file_1", skip=0, take=2, examples_in_shard=2 197 | ), 198 | # 3 records 199 | DummyFileInstruction( 200 | filename="file_2", skip=2, take=3, examples_in_shard=99 201 | ), 202 | # 1 record 203 | DummyFileInstruction( 204 | filename="file_3", skip=10, take=1, examples_in_shard=99 205 | ), 206 | ] 207 | 208 | expected_indices_and_positions = [ 209 | (0, 0), 210 | (0, 1), 211 | (1, 2), 212 | (1, 3), 213 | (1, 4), 214 | (2, 10), 215 | ] 216 | 217 | with array_record_data_source.ArrayRecordDataSource( 218 | file_instructions 219 | ) as ar: 220 | self.assertLen(ar, 6) 221 | for record_key in range(len(ar)): 222 | self.assertEqual( 223 | expected_indices_and_positions[record_key], 224 | ar._reader_idx_and_position(record_key), 225 | ) 226 | 227 | def test_array_record_source_reader_idx_and_position_negative_idx(self): 228 | with array_record_data_source.ArrayRecordDataSource([ 229 | self.testdata_dir / "digits.array_record-00000-of-00002", 230 | self.testdata_dir / "digits.array_record-00001-of-00002", 231 | ]) as ar: 232 | with self.assertRaises(ValueError): 233 | ar._reader_idx_and_position(-1) 234 | 235 | with self.assertRaises(ValueError): 236 | ar._reader_idx_and_position(len(ar)) 237 | 238 | def test_array_record_source_empty_sequence(self): 239 | with self.assertRaises(ValueError): 240 | with array_record_data_source.ArrayRecordDataSource([]): 241 | pass 242 | 243 | def test_repr(self): 244 | ar = array_record_data_source.ArrayRecordDataSource([ 245 | self.testdata_dir / "digits.array_record-00000-of-00002", 246 | self.testdata_dir / "digits.array_record-00001-of-00002", 247 | ]) 248 | self.assertRegex(repr(ar), r"ArrayRecordDataSource\(hash_of_paths=[\w]+\)") 249 | 250 | 251 | class RunInParallelTest(parameterized.TestCase): 252 | 253 | def test_the_function_is_executed_with_kwargs(self): 254 | function = mock.Mock(return_value="return value") 255 | list_of_kwargs_to_function = [ 256 | {"foo": 1}, 257 | {"bar": 2}, 258 | ] 259 | result = array_record_data_source._run_in_parallel( 260 | function=function, 261 | list_of_kwargs_to_function=list_of_kwargs_to_function, 262 | num_workers=1, 263 | ) 264 | self.assertEqual(result, ["return value", "return value"]) 265 | self.assertEqual(function.call_count, 2) 266 | function.assert_has_calls([mock.call(foo=1), mock.call(bar=2)]) 267 | 
268 | def test_exception_is_re_raised(self): 269 | function = mock.Mock() 270 | side_effect = ["return value", ValueError("Raised!")] 271 | function.side_effect = side_effect 272 | list_of_kwargs_to_function = [ 273 | {"foo": 1}, 274 | {"bar": 2}, 275 | ] 276 | self.assertEqual(len(side_effect), len(list_of_kwargs_to_function)) 277 | with self.assertRaisesRegex(ValueError, "Raised!"): 278 | array_record_data_source._run_in_parallel( 279 | function=function, 280 | list_of_kwargs_to_function=list_of_kwargs_to_function, 281 | num_workers=1, 282 | ) 283 | 284 | @parameterized.parameters([(-2,), (-1,), (0,)]) 285 | def test_num_workers_cannot_be_null_or_negative(self, num_workers): 286 | function = mock.Mock(return_value="return value") 287 | list_of_kwargs_to_function = [ 288 | {"foo": 1}, 289 | {"bar": 2}, 290 | ] 291 | with self.assertRaisesRegex( 292 | ValueError, "num_workers must be >=1 for parallelism." 293 | ): 294 | array_record_data_source._run_in_parallel( 295 | function=function, 296 | list_of_kwargs_to_function=list_of_kwargs_to_function, 297 | num_workers=num_workers, 298 | ) 299 | 300 | def test_num_workers_is_passed_to_thread_executor(self): 301 | function = mock.Mock(return_value="return value") 302 | list_of_kwargs_to_function = [ 303 | {"foo": 1}, 304 | {"bar": 2}, 305 | ] 306 | num_workers = 42 307 | with mock.patch( 308 | "concurrent.futures.ThreadPoolExecutor", 309 | wraps=futures.ThreadPoolExecutor, 310 | ) as executor: 311 | array_record_data_source._run_in_parallel( 312 | function=function, 313 | list_of_kwargs_to_function=list_of_kwargs_to_function, 314 | num_workers=num_workers, 315 | ) 316 | executor.assert_called_with(num_workers) 317 | 318 | 319 | if __name__ == "__main__": 320 | absltest.main() 321 | -------------------------------------------------------------------------------- /python/array_record_module.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2022 Google LLC. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include <cstdint> 17 | #include <stdexcept> 18 | 19 | #include <memory> 20 | #include <string> 21 | #include <utility> 22 | #include <vector> 23 | 24 | #include "absl/status/status.h" 25 | #include "absl/strings/str_format.h" 26 | #include "absl/strings/string_view.h" 27 | #include "cpp/array_record_reader.h" 28 | #include "cpp/array_record_writer.h" 29 | #include "cpp/thread_pool.h" 30 | #include "pybind11/gil.h" 31 | #include "pybind11/pybind11.h" 32 | #include "pybind11/pytypes.h" 33 | #include "pybind11/stl.h" 34 | #include "riegeli/base/maker.h" 35 | #include "riegeli/bytes/fd_reader.h" 36 | #include "riegeli/bytes/fd_writer.h" 37 | 38 | namespace py = pybind11; 39 | 40 | PYBIND11_MODULE(array_record_module, m) { 41 | using array_record::ArrayRecordReaderBase; 42 | using array_record::ArrayRecordWriterBase; 43 | 44 | py::class_<ArrayRecordWriterBase>(m, "ArrayRecordWriter") 45 | .def(py::init([](const std::string& path, const std::string& options, 46 | const py::kwargs& kwargs) -> ArrayRecordWriterBase* { 47 | auto status_or_option = 48 | ArrayRecordWriterBase::Options::FromString(options); 49 | if (!status_or_option.ok()) { 50 | throw py::value_error( 51 | std::string(status_or_option.status().message())); 52 | } 53 | // Release the GIL because IO is time consuming. 54 | py::gil_scoped_release scoped_release; 55 | return new array_record::ArrayRecordWriter( 56 | riegeli::Maker<riegeli::FdWriter<>>(path), 57 | status_or_option.value()); 58 | }), 59 | py::arg("path"), py::arg("options") = "") 60 | .def("ok", &ArrayRecordWriterBase::ok) 61 | .def("close", 62 | [](ArrayRecordWriterBase& writer) { 63 | if (!writer.Close()) { 64 | throw std::runtime_error(std::string(writer.status().message())); 65 | } 66 | }) 67 | .def("is_open", &ArrayRecordWriterBase::is_open) 68 | // We accept only py::bytes (and not unicode strings) since we expect 69 | // most users to store binary data (e.g. serialized protocol buffers). 70 | // We cannot know if a user wants to write+read unicode and forcing users 71 | // to encode() their unicode strings avoids accidental conversions. 72 | .def("write", [](ArrayRecordWriterBase& writer, py::bytes record) { 73 | if (!writer.WriteRecord(record)) { 74 | throw std::runtime_error(std::string(writer.status().message())); 75 | } 76 | }); 77 | py::class_<ArrayRecordReaderBase>(m, "ArrayRecordReader") 78 | .def(py::init([](const std::string& path, const std::string& options, 79 | const py::kwargs& kwargs) -> ArrayRecordReaderBase* { 80 | auto status_or_option = 81 | ArrayRecordReaderBase::Options::FromString(options); 82 | if (!status_or_option.ok()) { 83 | throw py::value_error( 84 | std::string(status_or_option.status().message())); 85 | } 86 | riegeli::FdReaderBase::Options file_reader_options; 87 | if (kwargs.contains("file_reader_buffer_size")) { 88 | auto file_reader_buffer_size = 89 | kwargs["file_reader_buffer_size"].cast<int64_t>(); 90 | file_reader_options.set_buffer_size(file_reader_buffer_size); 91 | } 92 | // Release the GIL because IO is time consuming. 93 | py::gil_scoped_release scoped_release; 94 | return new array_record::ArrayRecordReader( 95 | riegeli::Maker<riegeli::FdReader<>>( 96 | path, std::move(file_reader_options)), 97 | status_or_option.value(), 98 | array_record::ArrayRecordGlobalPool()); 99 | }), 100 | py::arg("path"), py::arg("options") = "", R"( 101 | ArrayRecordReader for fast sequential or random access. 102 | 103 | Args: 104 | path: File path to the input file. 105 | options: String with options for ArrayRecord. See syntax below. 
106 | Kwargs: 107 | file_reader_buffer_size: Optional size of the buffer (in bytes) 108 | for the underlying file (Riegeli) reader. The default buffer 109 | size is 1 MiB. 110 | file_options: Optional file::Options to use for the underlying 111 | file (Riegeli) reader. 112 | 113 | options ::= option? ("," option?)* 114 | option ::= 115 | "readahead_buffer_size" ":" readahead_buffer_size | 116 | "max_parallelism" ":" max_parallelism 117 | readahead_buffer_size ::= non-negative integer expressed as real with 118 | optional suffix [BkKMGTPE]. (Default 16MB.) Setting it to 0 119 | optimizes random access performance. 120 | max_parallelism ::= `auto` or non-negative integer. Each parallel 121 | thread owns its readahead buffer with the size 122 | `readahead_buffer_size`. (Default thread pool size.) Setting it 123 | to 0 optimizes random access performance. 124 | 125 | The default option is optimized for sequential access. To optimize 126 | the random access performance, set the options to 127 | "readahead_buffer_size:0,max_parallelism:0". 128 | )") 129 | .def("ok", &ArrayRecordReaderBase::ok) 130 | .def("close", 131 | [](ArrayRecordReaderBase& reader) { 132 | if (!reader.Close()) { 133 | throw std::runtime_error(std::string(reader.status().message())); 134 | } 135 | }) 136 | .def("is_open", &ArrayRecordReaderBase::is_open) 137 | .def("num_records", &ArrayRecordReaderBase::NumRecords) 138 | .def("record_index", &ArrayRecordReaderBase::RecordIndex) 139 | .def("writer_options_string", &ArrayRecordReaderBase::WriterOptionsString) 140 | .def("seek", 141 | [](ArrayRecordReaderBase& reader, int64_t record_index) { 142 | if (!reader.SeekRecord(record_index)) { 143 | throw std::runtime_error(std::string(reader.status().message())); 144 | } 145 | }) 146 | // See write() for why this returns py::bytes. 147 | .def("read", 148 | [](ArrayRecordReaderBase& reader) { 149 | absl::string_view string_view; 150 | if (!reader.ReadRecord(&string_view)) { 151 | if (reader.ok()) { 152 | throw std::out_of_range(absl::StrFormat( 153 | "Out of range of num_records: %d", reader.NumRecords())); 154 | } 155 | throw std::runtime_error(std::string(reader.status().message())); 156 | } 157 | return py::bytes(string_view); 158 | }) 159 | .def("read", 160 | [](ArrayRecordReaderBase& reader, std::vector<uint64_t> indices) { 161 | std::vector<std::string> staging(indices.size()); 162 | py::list output(indices.size()); 163 | { 164 | py::gil_scoped_release scoped_release; 165 | auto status = reader.ParallelReadRecordsWithIndices( 166 | indices, 167 | [&](uint64_t indices_index, 168 | absl::string_view record_data) -> absl::Status { 169 | staging[indices_index] = record_data; 170 | return absl::OkStatus(); 171 | }); 172 | if (!status.ok()) { 173 | throw std::runtime_error(std::string(status.message())); 174 | } 175 | } 176 | // TODO(fchern): Can we write the data directly to the output 177 | // list in our Parallel loop? 
178 | ssize_t index = 0; 179 | for (const auto& record : staging) { 180 | auto py_record = py::bytes(record); 181 | PyList_SET_ITEM(output.ptr(), index++, 182 | py_record.release().ptr()); 183 | } 184 | return output; 185 | }) 186 | .def("read", 187 | [](ArrayRecordReaderBase& reader, int32_t begin, int32_t end) { 188 | int32_t range_begin = begin, range_end = end; 189 | if (range_begin < 0) { 190 | range_begin = reader.NumRecords() + range_begin; 191 | } 192 | if (range_end < 0) { 193 | range_end = reader.NumRecords() + range_end; 194 | } 195 | if (range_begin > reader.NumRecords() || range_begin < 0 || 196 | range_end > reader.NumRecords() || range_end < 0 || 197 | range_end <= range_begin) { 198 | throw std::out_of_range( 199 | absl::StrFormat("[%d, %d) is out of range of [0, %d)", begin, 200 | end, reader.NumRecords())); 201 | } 202 | int32_t num_to_read = range_end - range_begin; 203 | std::vector<std::string> staging(num_to_read); 204 | py::list output(num_to_read); 205 | { 206 | py::gil_scoped_release scoped_release; 207 | auto status = reader.ParallelReadRecordsInRange( 208 | range_begin, range_end, 209 | [&](uint64_t index, 210 | absl::string_view record_data) -> absl::Status { 211 | staging[index - range_begin] = record_data; 212 | return absl::OkStatus(); 213 | }); 214 | if (!status.ok()) { 215 | throw std::runtime_error(std::string(status.message())); 216 | } 217 | } 218 | // TODO(fchern): Can we write the data directly to the output 219 | // list in our Parallel loop? 220 | ssize_t index = 0; 221 | for (const auto& record : staging) { 222 | auto py_record = py::bytes(record); 223 | PyList_SET_ITEM(output.ptr(), index++, 224 | py_record.release().ptr()); 225 | } 226 | return output; 227 | }) 228 | .def("read_all", [](ArrayRecordReaderBase& reader) { 229 | std::vector<std::string> staging(reader.NumRecords()); 230 | py::list output(reader.NumRecords()); 231 | { 232 | py::gil_scoped_release scoped_release; 233 | auto status = reader.ParallelReadRecords( 234 | [&](uint64_t index, 235 | absl::string_view record_data) -> absl::Status { 236 | staging[index] = record_data; 237 | return absl::OkStatus(); 238 | }); 239 | if (!status.ok()) { 240 | throw std::runtime_error(std::string(status.message())); 241 | } 242 | } 243 | // TODO(fchern): Can we write the data directly to the output 244 | // list in our Parallel loop? 245 | ssize_t index = 0; 246 | for (const auto& record : staging) { 247 | auto py_record = py::bytes(record); 248 | PyList_SET_ITEM(output.ptr(), index++, py_record.release().ptr()); 249 | } 250 | return output; 251 | }); 252 | } 253 | -------------------------------------------------------------------------------- /python/array_record_module_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for array_record_module.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | 21 | from python.array_record_module import ArrayRecordReader 22 | from python.array_record_module import ArrayRecordWriter 23 | 24 | 25 | class ArrayRecordModuleTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super(ArrayRecordModuleTest, self).setUp() 29 | self.test_file = os.path.join(self.create_tempdir().full_path, 30 | "test.arecord") 31 | 32 | def test_open_and_close(self): 33 | writer = ArrayRecordWriter(self.test_file) 34 | self.assertTrue(writer.ok()) 35 | self.assertTrue(writer.is_open()) 36 | writer.close() 37 | self.assertFalse(writer.is_open()) 38 | 39 | reader = ArrayRecordReader(self.test_file) 40 | self.assertTrue(reader.ok()) 41 | self.assertTrue(reader.is_open()) 42 | reader.close() 43 | self.assertFalse(reader.is_open()) 44 | 45 | def test_bad_options(self): 46 | 47 | def create_writer(): 48 | ArrayRecordWriter(self.test_file, "blah") 49 | 50 | def create_reader(): 51 | ArrayRecordReader(self.test_file, "blah") 52 | 53 | self.assertRaises(ValueError, create_writer) 54 | self.assertRaises(ValueError, create_reader) 55 | 56 | def test_write_read(self): 57 | writer = ArrayRecordWriter(self.test_file) 58 | test_strs = [b"abc", b"def", b"ghi"] 59 | for s in test_strs: 60 | writer.write(s) 61 | writer.close() 62 | reader = ArrayRecordReader( 63 | self.test_file, "readahead_buffer_size:0,max_parallelism:0" 64 | ) 65 | num_strs = len(test_strs) 66 | self.assertEqual(reader.num_records(), num_strs) 67 | self.assertEqual(reader.record_index(), 0) 68 | for gt in test_strs: 69 | result = reader.read() 70 | self.assertEqual(result, gt) 71 | self.assertRaises(IndexError, reader.read) 72 | reader.seek(0) 73 | self.assertEqual(reader.record_index(), 0) 74 | self.assertEqual(reader.read(), test_strs[0]) 75 | self.assertEqual(reader.record_index(), 1) 76 | 77 | def test_write_read_non_unicode(self): 78 | writer = ArrayRecordWriter(self.test_file) 79 | b = b"F\xc3\xb8\xc3\xb6\x97\xc3\xa5r" 80 | writer.write(b) 81 | writer.close() 82 | reader = ArrayRecordReader(self.test_file) 83 | self.assertEqual(reader.read(), b) 84 | 85 | def test_write_read_with_file_reader_buffer_size(self): 86 | writer = ArrayRecordWriter(self.test_file) 87 | b = b"F\xc3\xb8\xc3\xb6\x97\xc3\xa5r" 88 | writer.write(b) 89 | writer.close() 90 | reader = ArrayRecordReader(self.test_file, file_reader_buffer_size=2**10) 91 | self.assertEqual(reader.read(), b) 92 | 93 | def test_batch_read(self): 94 | writer = ArrayRecordWriter(self.test_file) 95 | test_strs = [b"abc", b"def", b"ghi", b"kkk", b"..."] 96 | for s in test_strs: 97 | writer.write(s) 98 | writer.close() 99 | reader = ArrayRecordReader(self.test_file) 100 | results = reader.read_all() 101 | self.assertEqual(test_strs, results) 102 | indices = [1, 3, 0] 103 | expected = [test_strs[i] for i in indices] 104 | batch_fetch = reader.read(indices) 105 | self.assertEqual(expected, batch_fetch) 106 | 107 | def test_read_range(self): 108 | writer = ArrayRecordWriter(self.test_file) 109 | test_strs = [b"abc", b"def", b"ghi", b"kkk", b"..."] 110 | for s in test_strs: 111 | writer.write(s) 112 | writer.close() 113 | reader = ArrayRecordReader(self.test_file) 114 | 115 | def invalid_range1(): 116 | reader.read(0, 0) 117 | 118 | self.assertRaises(IndexError, invalid_range1) 119 | 120 | def invalid_range2(): 121 | reader.read(0, 100) 122 | 123 | self.assertRaises(IndexError, invalid_range2) 124 | 125 | def invalid_range3(): 126 | reader.read(3, 2) 127 | 
128 | self.assertRaises(IndexError, invalid_range3) 129 | 130 | self.assertEqual(reader.read(0, -1), test_strs[0:-1]) 131 | self.assertEqual(reader.read(-3, -1), test_strs[-3:-1]) 132 | self.assertEqual(reader.read(1, 3), test_strs[1:3]) 133 | 134 | def test_writer_options(self): 135 | writer = ArrayRecordWriter(self.test_file, "group_size:42") 136 | writer.write(b"test123") 137 | writer.close() 138 | reader = ArrayRecordReader(self.test_file) 139 | # Includes default options. 140 | self.assertEqual( 141 | reader.writer_options_string(), 142 | "group_size:42,transpose:false,pad_to_block_boundary:false,zstd:3," 143 | "window_log:20,max_parallelism:1") 144 | 145 | 146 | if __name__ == "__main__": 147 | absltest.main() 148 | -------------------------------------------------------------------------------- /python/testdata/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | licenses(["notice"]) 4 | 5 | exports_files(glob(["*.array_record-*"])) 6 | -------------------------------------------------------------------------------- /python/testdata/digits.array_record-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/testdata/digits.array_record-00000-of-00002 -------------------------------------------------------------------------------- /python/testdata/digits.array_record-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/array_record/6cbfde75d747832769baffa40a8858046bfc1532/python/testdata/digits.array_record-00001-of-00002 -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ArrayRecord Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This is the list of our Python third_party dependencies that Bazel should 16 | # pull from PyPI. 17 | # Note that requirements_lock.txt must be re-generated using 18 | # bazel run //:requirements.update in the OSS version. 
19 | 20 | etils[epath] 21 | -------------------------------------------------------------------------------- /requirements_lock.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # bazel run //:requirements.update 6 | # 7 | etils[epath,epy]==1.11.0 \ 8 | --hash=sha256:a394cf3476bcec51c221426a70c39cd1006e889456ba41e4d7f12fd6814be7a5 \ 9 | --hash=sha256:aff3278a3be7fddf302dfd80335e9f924244666c71239cd91e836f3d055f1c4a 10 | # via -r requirements.in 11 | fsspec==2024.12.0 \ 12 | --hash=sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f \ 13 | --hash=sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2 14 | # via etils 15 | importlib-resources==6.4.5 \ 16 | --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \ 17 | --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717 18 | # via etils 19 | typing-extensions==4.12.2 \ 20 | --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ 21 | --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 22 | # via etils 23 | zipp==3.21.0 \ 24 | --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ 25 | --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 26 | # via etils 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup.py file for array_record.""" 2 | 3 | from setuptools import find_packages 4 | from setuptools import setup 5 | from setuptools.dist import Distribution 6 | 7 | REQUIRED_PACKAGES = [ 8 | 'absl-py', 9 | 'etils[epath]', 10 | ] 11 | 12 | BEAM_EXTRAS = [ 13 | 'apache-beam[gcp]==2.53.0', 14 | 'google-cloud-storage>=2.11.0', 15 | 'tensorflow>=2.14.0' 16 | ] 17 | 18 | 19 | class BinaryDistribution(Distribution): 20 | """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" 21 | 22 | def has_ext_modules(self): 23 | return True 24 | 25 | 26 | setup( 27 | name='array_record', 28 | version='0.6.0', 29 | description='A file format that achieves a new frontier of IO efficiency', 30 | author='ArrayRecord team', 31 | author_email='no-reply@google.com', 32 | packages=find_packages(), 33 | include_package_data=True, 34 | package_data={'': ['*.so']}, 35 | python_requires='>=3.10', 36 | install_requires=REQUIRED_PACKAGES, 37 | extras_require={'beam': BEAM_EXTRAS}, 38 | url='https://github.com/google/array_record', 39 | license='Apache-2.0', 40 | classifiers=[ 41 | 'Programming Language :: Python :: 3.10', 42 | 'Programming Language :: Python :: 3.11', 43 | 'Programming Language :: Python :: 3.12', 44 | ], 45 | zip_safe=False, 46 | distclass=BinaryDistribution, 47 | ) 48 | --------------------------------------------------------------------------------