├── src
│   └── sedpack
│       ├── py.typed
│       ├── io
│       │   ├── flatbuffer
│       │   │   ├── shardfile
│       │   │   │   ├── __init__.py
│       │   │   │   ├── Example.py
│       │   │   │   ├── Shard.py
│       │   │   │   └── Attribute.py
│       │   │   ├── unit_tests
│       │   │   │   ├── shard_writer_flatbuffer_test_schema
│       │   │   │   │   └── __init__.py
│       │   │   │   └── shard_writer_flatbuffer_test_schema.fbs
│       │   │   ├── shard.fbs
│       │   │   └── __init__.py
│       │   ├── npz
│       │   │   └── __init__.py
│       │   ├── tfrec
│       │   │   ├── __init__.py
│       │   │   └── read.py
│       │   ├── shard
│       │   │   ├── __init__.py
│       │   │   ├── get_shard_writer.py
│       │   │   ├── iterate_shard_base.py
│       │   │   ├── shard.py
│       │   │   ├── shard_writer_tfrec.py
│       │   │   └── shard_writer_base.py
│       │   ├── iteration
│       │   │   └── __init__.py
│       │   ├── shard_info_iterator
│       │   │   ├── __init__.py
│       │   │   └── shard_info_iterator.py
│       │   ├── __init__.py
│       │   ├── itertools
│       │   │   └── __init__.py
│       │   ├── errors.py
│       │   ├── types.py
│       │   ├── compress.py
│       │   ├── dataset.py
│       │   └── merge_shard_infos.py
│       └── __init__.py
├── .style.yapf
├── website
│   ├── tsconfig.json
│   ├── src
│   │   ├── env.d.ts
│   │   ├── content
│   │   │   ├── docs
│   │   │   │   ├── tutorials
│   │   │   │   │   └── sca
│   │   │   │   │       ├── tiny_aes_trace.png
│   │   │   │   │       ├── tiny_aes_snr_sbi_0.png
│   │   │   │   │       ├── overview.mdx
│   │   │   │   │       └── dataset.mdx
│   │   │   │   ├── index.mdx
│   │   │   │   └── start_here
│   │   │   │       ├── intro.md
│   │   │   │       └── install.mdx
│   │   │   └── config.ts
│   │   └── mathjax.css
│   ├── .gitignore
│   ├── README.md
│   ├── package.json
│   └── astro.config.mjs
├── .github
│   ├── cspell_matcher.json
│   ├── workflows
│   │   ├── mdlint.yml
│   │   ├── yapf.yml
│   │   ├── piptest.yml
│   │   ├── rust_ci.yml
│   │   ├── pylint.yml
│   │   ├── spellcheck.yml
│   │   ├── deploy.yml
│   │   ├── base_benchmarks.yml
│   │   ├── mypy.yml
│   │   ├── fork_pr_benchmarks_track.yml
│   │   ├── fork_pr_benchmarks_run.yml
│   │   └── pytest.yml
│   ├── python_matcher.json
│   └── dependabot.yml
├── rust
│   ├── rustfmt.toml
│   ├── benches
│   │   ├── setup.sh
│   │   └── my_benchmark.rs
│   └── Cargo.toml
├── setup.py
├── cspell.json
├── tools
│   ├── run_pylint.sh
│   └── check_copyright.sh
├── .markdownlint.json
├── tests
│   └── io
│       ├── shard
│       │   ├── test_shard_writer_base.py
│       │   └── test_shard_write_async.py
│       ├── itertools
│       │   ├── test_itertools.py
│       │   └── test_lazy_pool.py
│       ├── test_error.py
│       ├── test_dataset_base.py
│       ├── test_file_info.py
│       ├── iteration
│       │   └── test_rust_generator.py
│       ├── npz
│       │   └── test_npz_shards.py
│       ├── test_hash_checksums.py
│       ├── test_bytes.py
│       ├── test_compression.py
│       ├── test_rust_iter.py
│       ├── test_end2end_async.py
│       ├── test_write_multiprocessing.py
│       ├── shard_info_iterator
│       │   └── test_balanced_iterator.py
│       ├── test_shard_custom_metadata.py
│       ├── test_continue_writing.py
│       ├── test_end2end_wrong_type.py
│       └── test_as_tfdataset.py
├── docs
│   ├── contributing.md
│   ├── code-of-conduct.md
│   └── tutorials
│       └── quick_start
│           ├── mnist_save.py
│           └── mnist_read_keras.py
├── project-words.txt
├── mypy.ini
├── pyproject.toml
├── README.md
├── .gitignore
└── base-tooling-requirements.txt
/src/sedpack/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/shardfile/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | -------------------------------------------------------------------------------- /website/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "astro/tsconfigs/strict" 3 | }
-------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/unit_tests/shard_writer_flatbuffer_test_schema/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /website/src/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | /// 3 | -------------------------------------------------------------------------------- /website/src/content/docs/tutorials/sca/tiny_aes_trace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sedpack/HEAD/website/src/content/docs/tutorials/sca/tiny_aes_trace.png -------------------------------------------------------------------------------- /website/src/content/docs/tutorials/sca/tiny_aes_snr_sbi_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/sedpack/HEAD/website/src/content/docs/tutorials/sca/tiny_aes_snr_sbi_0.png -------------------------------------------------------------------------------- /website/src/mathjax.css: -------------------------------------------------------------------------------- 1 | /* Starlight styles all images in Markdown with `display: block` forcing 2 | * MathJax's inline SVG on new lines. Fix by: 3 | */ 4 | mjx-container svg { 5 | display: inline !important; 6 | } 7 | -------------------------------------------------------------------------------- /website/src/content/config.ts: -------------------------------------------------------------------------------- 1 | import { defineCollection } from 'astro:content'; 2 | import { docsSchema } from '@astrojs/starlight/schema'; 3 | 4 | export const collections = { 5 | docs: defineCollection({ schema: docsSchema() }), 6 | }; 7 | -------------------------------------------------------------------------------- /website/.gitignore: -------------------------------------------------------------------------------- 1 | # build output 2 | dist/ 3 | # generated types 4 | .astro/ 5 | 6 | # dependencies 7 | node_modules/ 8 | 9 | # logs 10 | npm-debug.log* 11 | yarn-debug.log* 12 | yarn-error.log* 13 | pnpm-debug.log* 14 | 15 | 16 | # environment variables 17 | .env 18 | .env.production 19 | 20 | # macOS-specific files 21 | .DS_Store 22 | -------------------------------------------------------------------------------- /website/src/content/docs/index.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: Welcome to Sedpack 3 | description: Scalable and Efficient Data Packing 4 | template: splash 5 | hero: 6 | tagline: Start using Sedpack. 7 | actions: 8 | - text: Get started 9 | link: /sedpack/start_here/intro/ 10 | icon: right-arrow 11 | --- 12 | 13 | import { Card, CardGrid } from '@astrojs/starlight/components'; 14 | -------------------------------------------------------------------------------- /website/src/content/docs/start_here/intro.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Sedpack Intro 3 | description: Sedpack - Scalable and efficient data packing 4 | --- 5 | 6 | Scalable and efficient data packing 7 | 8 | See [code samples](https://github.com/google/sedpack/tree/main/docs). 9 | 10 | The code is a major refactor of the data saving and loading code from the 11 | [SCAAML](https://github.com/google/scaaml) project. 
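A minimal sketch of creating a dataset (mirroring the API exercised in
`tests/io/test_error.py`; the path and attribute values below are made up
for illustration):

```python
import sedpack
from sedpack.io import Dataset, Metadata

# Describe the single attribute each example stores.
attribute = sedpack.io.metadata.Attribute(
    name="attribute_name",
    dtype="float32",
    shape=(10,),
)

# Create a new dataset; raises DatasetExistsError when the
# directory already holds one.
dataset = Dataset.create(
    path="my_dataset",  # Hypothetical output directory.
    metadata=Metadata(description="Example dataset"),
    dataset_structure=sedpack.io.metadata.DatasetStructure(
        saved_data_description=[attribute]),
)
```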
12 | -------------------------------------------------------------------------------- /.github/cspell_matcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "spellcheck", 5 | "pattern": [ 6 | { 7 | "regexp": "^([^:]+):(\\d+):([^:]+)\\s*-\\s*(.*)$", 8 | "file": 1, 9 | "line": 2, 10 | "column": 3, 11 | "message": 4 12 | } 13 | ] 14 | } 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /rust/rustfmt.toml: -------------------------------------------------------------------------------- 1 | comment_width = 100 2 | fn_params_layout = "Compressed" 3 | format_code_in_doc_comments = true 4 | format_strings = true 5 | group_imports = "StdExternalCrate" 6 | imports_granularity = "Module" 7 | normalize_comments = true 8 | normalize_doc_attributes = true 9 | spaces_around_ranges = true 10 | use_small_heuristics = "Max" 11 | where_single_line = true 12 | wrap_comments = true 13 | ignore = [ 14 | # Autogenerated files 15 | "src/shard_generated.rs", 16 | ] 17 | -------------------------------------------------------------------------------- /website/README.md: -------------------------------------------------------------------------------- 1 | # Sedpack Documentation 2 | 3 | Website available at [google.github.io/sedpack](https://google.github.io/sedpack) 4 | 5 | ## Adding more documentation 6 | 7 | Local development preparation: 8 | 9 | - Install requirements: `sudo apt-get install nodejs npm` 10 | - Install project: `cd website/ ; npm install` 11 | 12 | Run local server: `npm run dev` 13 | 14 | ## Deploy to github pages 15 | 16 | Manually run the "Deploy to GitHub Pages" workflow (needs maintainer 17 | permission). 18 | -------------------------------------------------------------------------------- /website/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "documentation", 3 | "type": "module", 4 | "version": "0.0.1", 5 | "scripts": { 6 | "dev": "astro dev", 7 | "start": "astro dev", 8 | "build": "astro check && astro build", 9 | "preview": "astro preview", 10 | "astro": "astro" 11 | }, 12 | "dependencies": { 13 | "@astrojs/check": "^0.9.3", 14 | "@astrojs/starlight": "^0.37.0", 15 | "astro": "^5.16.0", 16 | "rehype-mathjax": "^7.1.0", 17 | "remark-math": "^6.0.0", 18 | "sharp": "^0.34.0", 19 | "typescript": "^5.9.2" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/shard.fbs: -------------------------------------------------------------------------------- 1 | // Shard file schema using https://flatbuffers.dev/ 2 | 3 | // Call `flatc --python sedpack/io/flatbuffer/shard.fbs` from the src/ 4 | // directory otherwise the autogenerated code contains wrong imports. Beware 5 | // that this overwrites all __init__.py files on the path. 6 | // Also remember to update `sedpack_rs/src/shard_generated.rs`. 
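// (Unverified sketch of that update: FlatBuffers can emit Rust as well,
// e.g. `flatc --rust -o rust/src/ sedpack/io/flatbuffer/shard.fbs` run from
// src/; the exact output location of shard_generated.rs may differ.)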
7 | 8 | 9 | namespace sedpack.io.flatbuffer.shardfile; 10 | 11 | table Attribute { 12 | attribute_bytes:[ubyte]; 13 | } 14 | 15 | table Example { 16 | attributes:[Attribute]; 17 | } 18 | 19 | table Shard { 20 | examples:[Example]; 21 | } 22 | 23 | root_type Shard; 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Compatibility with legacy builds. 15 | """ 16 | 17 | from setuptools import setup 18 | 19 | setup() 20 | -------------------------------------------------------------------------------- /cspell.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://raw.githubusercontent.com/streetsidesoftware/cspell/main/cspell.schema.json", 3 | "version": "0.2", 4 | "dictionaryDefinitions": [ 5 | { 6 | "name": "project-words", 7 | "path": "./project-words.txt", 8 | "addWords": true 9 | } 10 | ], 11 | "dictionaries": [ 12 | "en-gb", 13 | "en_US", 14 | "project-words", 15 | "python" 16 | ], 17 | "ignorePaths": [ 18 | ".pylintrc", 19 | "/.github", 20 | "/project-words.txt", 21 | "/tests", 22 | "base-tooling-requirements.txt", 23 | "cspell.json", 24 | "node_modules", 25 | "requirements.in", 26 | "requirements.txt" 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /tools/run_pylint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Ensure we are at the project root directory 17 | cd $(readlink -f $(dirname $0))/.. 18 | 19 | pylint *.py src docs 20 | -------------------------------------------------------------------------------- /src/sedpack/io/npz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.npz.iterate_npz import IterateShardNP 17 | 18 | __all__ = [ 19 | "IterateShardNP", 20 | ] 21 | -------------------------------------------------------------------------------- /src/sedpack/io/tfrec/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.tfrec.read import IterateShardTFRec 17 | 18 | __all__ = [ 19 | "IterateShardTFRec", 20 | ] 21 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.flatbuffer.iterate import IterateShardFlatBuffer 17 | 18 | __all__ = [ 19 | "IterateShardFlatBuffer", 20 | ] 21 | -------------------------------------------------------------------------------- /src/sedpack/io/shard/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.shard.iterate_shard_base import IterateShardBase 17 | from sedpack.io.shard.shard import Shard 18 | 19 | __all__ = [ 20 | "IterateShardBase", 21 | "Shard", 22 | ] 23 | -------------------------------------------------------------------------------- /.markdownlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "default": true, 3 | "heading-style": { 4 | "style": "atx" 5 | }, 6 | "no-trailing-spaces": { 7 | "br_spaces": 0, 8 | "strict": true 9 | }, 10 | "ul-indent": { 11 | "indent": 4 12 | }, 13 | "line-length": { 14 | "line_length": 80, 15 | "heading_line_length": 120, 16 | "tables": false, 17 | "code_blocks": false 18 | }, 19 | "list-marker-space": { 20 | "ol_single": 2, 21 | "ol_multi": 2, 22 | "ul_single": 3, 23 | "ul_multi": 3 24 | }, 25 | "no-inline-html": { 26 | "allowed_elements": [ 27 | "img" 28 | ] 29 | }, 30 | "fenced-code-language": true, 31 | "code-block-style": { 32 | "style": "fenced" 33 | }, 34 | "code-fence-style": { 35 | "style": "backtick" 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /.github/workflows/mdlint.yml: -------------------------------------------------------------------------------- 1 | name: markdownlint 2 | on: 3 | pull_request: 4 | types: [opened, synchronize, reopened] 5 | paths: 6 | - '**/*.md' 7 | - '**/*.mdx' 8 | - '.markdownlint.json' 9 | merge_group: # Needed for required workflows 10 | # Run after a review has been submitted (this is a required workflow which 11 | # might not be triggered when no code changes -- trigger before going to 12 | # merge queue). 13 | pull_request_review: 14 | types: [submitted] 15 | 16 | jobs: 17 | lint: 18 | permissions: 19 | contents: read 20 | pull-requests: write 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v6 24 | - uses: DavidAnson/markdownlint-cli2-action@07035fd053f7be764496c0f8d8f9f41f98305101 # v20 25 | with: 26 | config: .markdownlint.json 27 | globs: | 28 | *.md 29 | **/*.md 30 | -------------------------------------------------------------------------------- /src/sedpack/io/iteration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset iteration.""" 15 | 16 | from sedpack.io.iteration.rust_batched_generator import RustBatchedGenerator 17 | from sedpack.io.iteration.rust_generator import RustGenerator 18 | 19 | __all__ = [ 20 | "RustBatchedGenerator", 21 | "RustGenerator", 22 | ] 23 | -------------------------------------------------------------------------------- /rust/benches/setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Setup datasets for Rust benchmarking. Call `bash benches/setup.sh` from the 16 | # `sedpack/rust` directory. 17 | 18 | python3 ../docs/tutorials/quick_start/mnist_save.py \ 19 | --dataset_directory mnist_fb_gzip \ 20 | --shard_file_type fb \ 21 | --compression GZIP 22 | -------------------------------------------------------------------------------- /tests/io/shard/test_shard_writer_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import get_args 16 | 17 | from sedpack.io.types import ShardFileTypeT 18 | from sedpack.io.shard.get_shard_writer import _SHARD_FILE_TYPE_TO_CLASS 19 | 20 | 21 | def test_all_file_types_supported(): 22 | assert set(_SHARD_FILE_TYPE_TO_CLASS.keys()) == set( 23 | get_args(ShardFileTypeT)) 24 | -------------------------------------------------------------------------------- /src/sedpack/io/shard_info_iterator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Iterating of shard information.""" 15 | 16 | from sedpack.io.shard_info_iterator.shard_info_iterator import ShardInfoIterator 17 | from sedpack.io.shard_info_iterator.cached_shard_info_iterator import CachedShardInfoIterator 18 | 19 | __all__ = [ 20 | "CachedShardInfoIterator", 21 | "ShardInfoIterator", 22 | ] 23 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/unit_tests/shard_writer_flatbuffer_test_schema.fbs: -------------------------------------------------------------------------------- 1 | // Shard file schema using https://flatbuffers.dev/ 2 | 3 | // Call `flatc --python shard_writer_flatbuffer_test_schema.fbs` from the 4 | // tests/io/shard directory otherwise the autogenerated code contains wrong 5 | // imports. 
Beware that this overwrites all __init__.py files on the path. 6 | 7 | 8 | namespace shard_writer_flatbuffer_test_schema; 9 | 10 | table NumPyVectorTest { 11 | // Vectors of types according to 12 | // https://flatbuffers.dev/flatbuffers_guide_writing_schema.html 13 | 14 | // 8 bit 15 | attribute_bool:[byte]; 16 | attribute_byte:[byte]; 17 | attribute_ubyte:[byte]; 18 | 19 | // 16 bit 20 | attribute_short:[byte]; 21 | attribute_ushort:[byte]; 22 | 23 | // 32 bit 24 | attribute_int:[byte]; 25 | attribute_uint:[byte]; 26 | attribute_float:[byte]; 27 | 28 | // 64 bit 29 | attribute_long:[byte]; 30 | attribute_ulong:[byte]; 31 | attribute_double:[byte]; 32 | } 33 | 34 | root_type NumPyVectorTest; 35 | -------------------------------------------------------------------------------- /src/sedpack/io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.dataset import Dataset 17 | from sedpack.io.dataset_filler import DatasetFiller, DatasetFillerContext 18 | from sedpack.io.metadata import Attribute 19 | from sedpack.io.metadata import DatasetStructure 20 | from sedpack.io.metadata import Metadata 21 | 22 | __all__ = [ 23 | "Attribute", 24 | "Dataset", 25 | "DatasetFiller", 26 | "DatasetFillerContext", 27 | "DatasetStructure", 28 | "Metadata", 29 | ] 30 | -------------------------------------------------------------------------------- /src/sedpack/io/itertools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Dataset creation and usage.""" 15 | 16 | from sedpack.io.itertools.itertools import shuffle_buffer, shuffle_buffer_async 17 | from sedpack.io.itertools.itertools import round_robin, round_robin_async 18 | from sedpack.io.itertools.lazy_pool import LazyPool 19 | #from sedpack.io.itertools.lazy_pool_multiprocessing import LazyPool 20 | 21 | __all__ = [ 22 | "shuffle_buffer", 23 | "shuffle_buffer_async", 24 | "round_robin", 25 | "round_robin_async", 26 | "LazyPool", 27 | ] 28 | -------------------------------------------------------------------------------- /.github/workflows/yapf.yml: -------------------------------------------------------------------------------- 1 | name: yapf 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | types: [opened, synchronize, reopened] 8 | paths: 9 | - '**/*.py' 10 | merge_group: # Needed for required workflows 11 | # Run after a review has been submitted (this is a required workflow which 12 | # might not be triggered when no code changes -- trigger before going to 13 | # merge queue). 14 | pull_request_review: 15 | types: [submitted] 16 | 17 | jobs: 18 | yapf: 19 | runs-on: ubuntu-22.04 20 | steps: 21 | - uses: actions/checkout@v6 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v6 24 | with: 25 | python-version: '3.10' 26 | cache: 'pip' 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip setuptools wheel 30 | pip install --upgrade 'yapf>=0.30.0' 31 | - name: Register matcher 32 | run: 33 | echo ::add-matcher::./.github/python_matcher.json 34 | - name: Test code formatting with yapf 35 | run: 36 | yapf --recursive --diff . 37 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows [Google's Open Source Community 24 | Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 33 | -------------------------------------------------------------------------------- /tools/check_copyright.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2023-2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | errors=0 18 | e() { 19 | echo -e "$(tput bold)$(tput setaf 1)Error:$(tput sgr0) $*" 20 | errors=$(( $errors + 1 )) 21 | } 22 | 23 | # Files we want to check for copyright 24 | EXTENSIONS="py\|sh" 25 | 26 | EXCLUDE_FILES="src/sedpack/io/flatbuffer/shardfile/*\.py" 27 | 28 | for file in $(git ls-files | \ 29 | grep -e '\.\('"${EXTENSIONS}"'\)$' | \ 30 | grep -v -e '^\('"${EXCLUDE_FILES}"'\)$') 31 | do 32 | sed -n 'N;/Copyright/q;q1' $file || e "No copyright notice in $file" 33 | done 34 | 35 | if [ $errors -gt 0 ] 36 | then 37 | exit 1 38 | fi 39 | exit 0 40 | 41 | -------------------------------------------------------------------------------- /.github/workflows/piptest.yml: -------------------------------------------------------------------------------- 1 | name: piptest 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | types: [opened, synchronize, reopened] 8 | paths: 9 | - 'docs/**/*.py' 10 | - 'pytest.ini' 11 | schedule: 12 | - cron: 0 5 * * 1 # Every Monday at 5:00 UTC 13 | 14 | jobs: 15 | piptesting: 16 | runs-on: ${{ matrix.platform.runner }} 17 | strategy: 18 | matrix: 19 | # ubuntu-24.04-arm is not stable enough 20 | platform: 21 | - runner: ubuntu-latest # x64 22 | - runner: windows-latest # x64 23 | - runner: macos-14 # arm64 24 | - runner: macos-15-intel # Intel 25 | - runner: macos-latest # arm64 26 | steps: 27 | - uses: actions/checkout@v6 28 | - name: Set up Python 3.10 29 | uses: actions/setup-python@v6 30 | with: 31 | python-version: '3.10' 32 | cache: 'pip' 33 | - name: Installing sedpack pip package 34 | run: | 35 | pip install sedpack 36 | - name: Run tutorial using sedpack pip package 37 | run: | 38 | python docs/tutorials/quick_start/mnist_save.py -d mnist_dataset 39 | python docs/tutorials/quick_start/mnist_read_keras.py -d mnist_dataset 40 | -------------------------------------------------------------------------------- /.github/workflows/rust_ci.yml: -------------------------------------------------------------------------------- 1 | name: Rust CI 2 | on: 3 | pull_request: 4 | types: [opened, synchronize, reopened] 5 | paths: 6 | - '**/*.rs' 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | fmt: 13 | runs-on: ubuntu-latest 14 | defaults: 15 | run: 16 | working-directory: ./rust 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v6 20 | # Ensure rustfmt is installed and setup problem matcher 21 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1 22 | with: 23 | components: rustfmt 24 | toolchain: nightly 25 | matcher: true 26 | - name: fmt 27 | run: cargo +nightly fmt -- --check 28 | clippy: 29 | runs-on: ubuntu-latest 30 | defaults: 31 | run: 32 | working-directory: ./rust 33 | steps: 34 | - name: Checkout code 35 | uses: actions/checkout@v6 36 | # Ensure clippy is installed and setup problem matcher 37 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1 38 | with: 39 | components: clippy 40 | toolchain: nightly 41 | - name: clippy 42 | run: cargo +nightly
clippy --all-targets 43 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "sedpack_rs" 3 | version = "0.1.4" 4 | edition = "2024" 5 | description = "Rust bindings for sedpack a general ML dataset package" 6 | authors = [ 7 | "Elie Bursztein", 8 | "Karel Král", 9 | "Jean-Michel Picod", 10 | ] 11 | license = "Apache-2.0" 12 | # So far just a small subset of sedpack functionality. Not meant to be 13 | # published as a crate. 14 | publish = false 15 | 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | [lib] 18 | name = "sedpack_rs" 19 | crate-type = ["cdylib", "rlib"] 20 | 21 | [dependencies] 22 | criterion = "0.8.1" 23 | flatbuffers = "25.9" 24 | flate2 = { version = "1.1.4" } 25 | glob = "0.3.3" 26 | lz4_flex = { version = "0.12.0", default-features = false , features = ["frame"] } 27 | numpy = "0.26" 28 | pyo3 = "0.26" 29 | rand = "0.9" 30 | rayon = "1.11.0" 31 | strum = "0.27.2" 32 | strum_macros = "0.27.2" 33 | tracing = "0.1.41" 34 | tracing-perfetto = "0.1.5" 35 | tracing-subscriber = "0.3.20" 36 | yoke = { version = "0.8", features = ["derive"] } 37 | zstd = "0.13.3" 38 | 39 | # Release performance optimizations. 40 | [profile.release] 41 | codegen-units = 1 42 | lto = true 43 | panic = "abort" 44 | 45 | [profile.bench] 46 | debug = true 47 | 48 | [[bench]] 49 | name = "my_benchmark" 50 | # Otherwise needs nightly. 51 | harness = false 52 | -------------------------------------------------------------------------------- /project-words.txt: -------------------------------------------------------------------------------- 1 | Adafactor 2 | AdamW 3 | Babuska 4 | Bursztein 5 | CHES 6 | CUDA 7 | Corinna 8 | Elems 9 | GPAM 10 | Hastie 11 | IACR 12 | Invernizzi 13 | Josyula 14 | Karel 15 | Král 16 | Luca 17 | MNIST 18 | Mangard 19 | Moghimi 20 | Pankaj 21 | Picod 22 | Pytree 23 | Pytrees 24 | Rohatgi 25 | Ryzen 26 | SBOX 27 | SCAAML 28 | SCARR 29 | Sakkis 30 | Suresh 31 | Tibshirani 32 | Woudenberg 33 | Yann 34 | Yokeable 35 | arange 36 | argmax 37 | argnames 38 | asyncstdlib 39 | bldr 40 | booktitle 41 | burszteindc 42 | byteswap 43 | cdylib 44 | codegen 45 | compresslevel 46 | convnet 47 | crossentropy 48 | ddof 49 | diutils 50 | dtype 51 | dtypes 52 | dunder 53 | einsum 54 | fbapi 55 | flatbuffer 56 | flate 57 | frombuffer 58 | howpublished 59 | hyperparameter 60 | hyperparameters 61 | hypertuned 62 | hypertuning 63 | inproceedings 64 | ipynb 65 | itemsize 66 | kwarguments 67 | linalg 68 | logpdf 69 | mathbb 70 | mathjax 71 | ndarray 72 | newbyteorder 73 | parseable 74 | patchelf 75 | perfcounters 76 | perfetto 77 | pickleable 78 | pyarray 79 | pyclass 80 | pyfunction 81 | pymethods 82 | pymodule 83 | riscure 84 | rlib 85 | rngs 86 | savez 87 | sedpack 88 | setaf 89 | shardfile 90 | subclassing 91 | tensorspec 92 | tfdata 93 | tfdataset 94 | tfrec 95 | tfrecord 96 | tfrecords 97 | tinyaes 98 | tobytes 99 | uoffset 100 | vmap 101 | vmaps 102 | xored 103 | zstd 104 | -------------------------------------------------------------------------------- /.github/python_matcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "yapf-diff", 5 | "pattern": [ 6 | { 7 | "regexp": "^---\\s*([^\\s]*)\\s*\\(original\\)$", 8 | "file": 1 9 | }, 10 | { 11 | "regexp": "^\\+\\+\\+\\s*([^\\s]*)\\s*\\((.*)\\)$", 12 | 
"message": 2 13 | }, 14 | { 15 | "regexp": "^@@\\s*-(\\d+),(\\d+)\\s*\\+(\\d+),(\\d+)\\s*@@$", 16 | "line": 1 17 | } 18 | ] 19 | }, 20 | { 21 | "owner": "pylint", 22 | "pattern": [ 23 | { 24 | "regexp": "^([^:]+):(\\d+):(\\d+):\\s*([CEFIRW]\\d{4}):\\s*(.*)$", 25 | "file": 1, 26 | "line": 2, 27 | "column": 3, 28 | "code": 4, 29 | "message": 5 30 | } 31 | ] 32 | }, 33 | { 34 | "owner": "mypy", 35 | "pattern": [ 36 | { 37 | "regexp": "^([^:]+):(\\d+):\\s*([^:]+):\\s*(.*)$", 38 | "file": 1, 39 | "line": 2, 40 | "code": 3, 41 | "message": 4 42 | } 43 | ] 44 | } 45 | ] 46 | } 47 | -------------------------------------------------------------------------------- /src/sedpack/io/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implement an error to indicate that a scaaml.io.Dataset already exists. 15 | 16 | Creating scaaml.io.Dataset should not overwrite existing files. When it could 17 | the constructor needs to raise an error, which should also contain the dataset 18 | directory. 19 | """ 20 | 21 | from pathlib import Path 22 | 23 | 24 | class DatasetExistsError(FileExistsError): 25 | """Error for signalling that the dataset already exists.""" 26 | 27 | def __init__(self, dataset_path: Path) -> None: 28 | """Represents that the dataset already exists. 29 | 30 | Args: 31 | dataset_path: The dataset path. 32 | """ 33 | super().__init__( 34 | f'Dataset info file exists and would be overwritten. Use instead:' 35 | f' Dataset(path="{dataset_path}")') 36 | self.dataset_path = dataset_path 37 | -------------------------------------------------------------------------------- /src/sedpack/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset library. 15 | 16 | Version format: MAJOR.MINOR.PATCH (see https://pypi.org/project/semver/ for 17 | more possibilities). 18 | """ 19 | 20 | from importlib.metadata import version, PackageNotFoundError 21 | 22 | # The version of this package is defined by rust/Cargo.toml and dynamically 23 | # deduced by Maturin (see 24 | # https://www.maturin.rs/metadata.html#dynamic-metadata). To ensure 25 | # compatibility with existing code we dynamically set the __version__ attribute 26 | # here. 
27 | try: 28 | # When package is installed use the version. 29 | __version__ = version("sedpack") # pylint: disable=invalid-name 30 | except PackageNotFoundError: 31 | # Package is not installed. The Rust part of this package is probably not 32 | # going to work in this case (the Rust binding would be probably missing). 33 | __version__ = "0.0.7-dev" # pylint: disable=invalid-name 34 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: pylint 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | types: [opened, synchronize, reopened] 8 | paths: 9 | - '**/*.py' 10 | - '.pylintrc' 11 | merge_group: # Needed for required workflows 12 | # Run after a review has been submitted (this is a required workflow which 13 | # might not be triggered when no code changes -- trigger before going to 14 | # merge queue). 15 | pull_request_review: 16 | types: [submitted] 17 | 18 | jobs: 19 | copyright_header: 20 | runs-on: ubuntu-22.04 21 | steps: 22 | - uses: actions/checkout@v6 23 | - name: Check licence headers 24 | run: ./tools/check_copyright.sh 25 | 26 | pylint: 27 | runs-on: ubuntu-22.04 28 | strategy: 29 | matrix: 30 | python-version: ['3.10', '3.11', '3.12'] 31 | steps: 32 | - uses: actions/checkout@v6 33 | - name: Set up Python ${{ matrix.python-version }} 34 | uses: actions/setup-python@v6 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | cache: 'pip' 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip setuptools wheel 41 | pip install --require-hashes --no-deps -r requirements.txt 42 | pip install --upgrade pylint 43 | - name: Register matcher 44 | run: echo ::add-matcher::./.github/python_matcher.json 45 | - name: Test code with pylint 46 | run: ./tools/run_pylint.sh 47 | 48 | -------------------------------------------------------------------------------- /tests/io/itertools/test_itertools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import random 16 | 17 | from sedpack.io.itertools import * 18 | 19 | 20 | def test_round_robin_simple() -> None: 21 | l = ["ABC", "D", "EFGH"] 22 | 23 | assert sorted(round_robin(l)) == list("ABCDEFGH") 24 | 25 | 26 | def test_round_robin_long_iter() -> None: 27 | l = map(lambda x: range(x, x + 10), range(100)) 28 | 29 | assert len(list(round_robin(l))) == 1_000 30 | 31 | 32 | def test_round_robin_docstring() -> None: 33 | l = ["ABC", "D", "EF"] 34 | 35 | assert sorted(round_robin(l)) == ["A", "B", "C", "D", "E", "F"] 36 | 37 | 38 | def test_random_shuffle() -> None: 39 | random.seed(42) 40 | l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 41 | 42 | seen = [] 43 | 44 | for x in shuffle_buffer(l, buffer_size=3): 45 | seen.append(x) 46 | 47 | # We have seen all elements and each only once. 48 | assert sorted(l) == sorted(seen) 49 | assert l != seen 50 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | plugins = pydantic.mypy 3 | show_error_codes = True 4 | follow_imports = silent 5 | local_partial_types = true 6 | strict_equality = true 7 | no_implicit_optional = true 8 | warn_incomplete_stub = true 9 | warn_redundant_casts = true 10 | warn_unused_configs = true 11 | warn_unused_ignores = true 12 | enable_error_code = ignore-without-code, redundant-self, truthy-iterable 13 | disable_error_code = annotation-unchecked, import-not-found, import-untyped 14 | extra_checks = false 15 | check_untyped_defs = true 16 | disallow_incomplete_defs = true 17 | disallow_subclassing_any = true 18 | disallow_untyped_calls = true 19 | disallow_untyped_decorators = true 20 | disallow_untyped_defs = true 21 | warn_return_any = true 22 | warn_unreachable = true 23 | allow_redefinition = false 24 | strict_optional = true 25 | 26 | [pydantic-mypy] 27 | init_forbid_extra = true 28 | init_typed = true 29 | warn_required_dynamic_aliases = true 30 | warn_untyped_fields = true 31 | 32 | [mypy-sedpack.*] 33 | no_implicit_reexport = true 34 | disallow_untyped_calls = true 35 | disallow_any_unimported = true 36 | disallow_untyped_decorators = true 37 | strict = true 38 | enable_error_code = ignore-without-code, redundant-self, truthy-iterable, possibly-undefined, truthy-bool, truthy-iterable, unused-ignore, mutable-override 39 | 40 | [mypy-tests.*] 41 | disallow_untyped_defs = false 42 | disallow_untyped_calls = false 43 | disallow_untyped_decorators = false 44 | 45 | [mypy-sedpack.io.flatbuffer.unit_tests.*] 46 | disallow_untyped_defs = false 47 | disallow_untyped_calls = false 48 | disallow_untyped_decorators = false 49 | -------------------------------------------------------------------------------- /.github/workflows/spellcheck.yml: -------------------------------------------------------------------------------- 1 | name: spellcheck 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | types: [opened, synchronize, reopened] 8 | merge_group: # Needed for required workflows 9 | 10 | jobs: 11 | spellchecking: 12 | runs-on: ubuntu-22.04 13 | steps: 14 | - name: Checkout the code 15 | uses: actions/checkout@v6 16 | with: 17 | # We need all history to list all changed files. 18 | fetch-depth: 0 19 | - name: Set up node 20 | uses: actions/setup-node@v6 21 | with: 22 | node-version: "21" 23 | - name: Install cspell 24 | run: npm install --location=global cspell 25 | - name: Run spell checking 26 | run: | 27 | # Add a problem matcher. 
28 | echo ::add-matcher::./.github/cspell_matcher.json 29 | # Set internal field separator (filenames can have space in them). 30 | IFS=$'\n' 31 | # Find out which files were changed since main. 32 | git diff --name-only origin/main..${{ github.event.after }} | { # Open sub-shell 33 | # Spell check all files and remember if some failed. 34 | EXIT_CODE=0 35 | # Loop over the changed files. 36 | while read CHANGED_FILE; do 37 | # Run cspell on CHANGED_FILE, do not fail if the file is 38 | # ignored or the spell-checking fails. 39 | cspell --config ./cspell.json --no-must-find-files "$CHANGED_FILE" || EXIT_CODE=$? 40 | echo $EXIT_CODE 41 | done ; 42 | exit $EXIT_CODE ; # Fail if some check failed. 43 | } ; 44 | -------------------------------------------------------------------------------- /src/sedpack/io/shard/get_shard_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base class for shard writing depending on shard_file_type. 15 | """ 16 | 17 | from pathlib import Path 18 | 19 | from sedpack.io.metadata import DatasetStructure 20 | from sedpack.io.types import ShardFileTypeT 21 | 22 | from sedpack.io.shard.shard_writer_base import ShardWriterBase 23 | from sedpack.io.shard.shard_writer_flatbuffer import ShardWriterFlatBuffer 24 | from sedpack.io.shard.shard_writer_np import ShardWriterNP 25 | from sedpack.io.shard.shard_writer_tfrec import ShardWriterTFRec 26 | 27 | _SHARD_FILE_TYPE_TO_CLASS: dict[ShardFileTypeT, type[ShardWriterBase]] = { 28 | "tfrec": ShardWriterTFRec, 29 | "npz": ShardWriterNP, 30 | "fb": ShardWriterFlatBuffer, 31 | } 32 | 33 | 34 | def get_shard_writer(dataset_structure: DatasetStructure, 35 | shard_file: Path) -> ShardWriterBase: 36 | """Return the right subclass of ShardWriterBase. 37 | """ 38 | return _SHARD_FILE_TYPE_TO_CLASS[dataset_structure.shard_file_type]( 39 | dataset_structure=dataset_structure, 40 | shard_file=shard_file, 41 | ) 42 | -------------------------------------------------------------------------------- /tests/io/test_error.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | import pytest 17 | 18 | import sedpack 19 | from sedpack.io import Dataset 20 | from sedpack.io import Metadata 21 | from sedpack.io.errors import DatasetExistsError 22 | 23 | 24 | def test_dataset_exists(tmpdir: str | Path) -> None: 25 | tiny_experiment_path: Path = Path(tmpdir) / "exists" 26 | # Metadata 27 | dataset_metadata = Metadata(description="Test of the lib") 28 | example_attributes = [ 29 | sedpack.io.metadata.Attribute( 30 | name="attribute_name", 31 | dtype="float32", 32 | shape=(10,), 33 | ), 34 | ] 35 | dataset_structure = sedpack.io.metadata.DatasetStructure( 36 | saved_data_description=example_attributes) 37 | 38 | # First should be ok. 39 | dataset = Dataset.create( 40 | path=tiny_experiment_path, 41 | metadata=dataset_metadata, 42 | dataset_structure=dataset_structure, 43 | ) 44 | 45 | with pytest.raises(DatasetExistsError): 46 | # This would overwrite. 47 | dataset = Dataset.create( 48 | path=tiny_experiment_path, 49 | metadata=dataset_metadata, 50 | dataset_structure=dataset_structure, 51 | ) 52 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.7,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [tool.maturin] 6 | python-source = "src" 7 | manifest-path = "rust/Cargo.toml" 8 | features = ["pyo3/extension-module"] 9 | # Implemented in Rust: 10 | module-name = "sedpack._sedpack_rs" 11 | 12 | [project] 13 | name = "sedpack" 14 | authors = [ 15 | { name="Elie Bursztein"}, 16 | { name="Karel Král"}, 17 | { name="Jean-Michel Picod"}, 18 | ] 19 | description = "General ML dataset package" 20 | readme = "README.md" 21 | requires-python = ">=3.10" 22 | keywords = ["machine learning", "dataset"] 23 | license = {text = "Apache License 2.0"} 24 | classifiers = [ 25 | "Development Status :: 5 - Production/Stable", 26 | "Environment :: Console", 27 | "Framework :: Jupyter", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Intended Audience :: Science/Research", 30 | "Programming Language :: Python :: 3", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | ] 33 | dynamic = ["version"] 34 | dependencies = [ 35 | "aiofiles", 36 | "asyncstdlib", 37 | "flatbuffers", 38 | "lz4", 39 | "numpy", 40 | "pydantic", 41 | "semver", 42 | "tenacity", 43 | "tensorflow", 44 | "tqdm", 45 | "xxhash", 46 | "zstandard", 47 | ] 48 | 49 | [project.optional-dependencies] 50 | dev = [ 51 | "maturin[patchelf,zig] ; platform_system != 'Windows'", 52 | "maturin[zig] ; platform_system == 'Windows'", 53 | "mypy", 54 | "pylint", 55 | "pytest", 56 | "pytest-asyncio", 57 | "pytest-cov", 58 | "types-aiofiles", 59 | "types-tensorflow", 60 | "types-tqdm", 61 | "yapf", 62 | ] 63 | 64 | [project.scripts] 65 | 66 | [project.urls] 67 | "Homepage" = "https://github.com/google/sedpack" 68 | "Bug Tracker" = "https://github.com/google/sedpack" 69 | 70 | [tool.ruff] 71 | target-version = "py310" 72 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | 3 | on: 4 | # Trigger the workflow every time you push to the `main` branch 5 | # Using a different branch name? 
Replace `main` with your branch’s name 6 | push: 7 | branches: [ main ] 8 | paths: [ website ] 9 | merge_group: # Needed for required workflows 10 | # Allows you to run this workflow manually from the Actions tab on GitHub. 11 | workflow_dispatch: 12 | 13 | # Allow this job to clone the repo and create a page deployment 14 | permissions: 15 | contents: read 16 | pages: write 17 | id-token: write 18 | 19 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | build: 26 | # Run only on the upstream repository. 27 | if: github.repository == 'google/sedpack' 28 | runs-on: ubuntu-latest 29 | steps: 30 | - name: Checkout your repository using git 31 | uses: actions/checkout@v6 32 | - name: Install, build, and upload your site 33 | uses: withastro/action@2226b5671ff302b175d9843add614af27e60bbfc # v4 34 | with: 35 | path: website # The root location of your Astro project inside the repository. (optional) 36 | # node-version: 20 # The specific version of Node that should be used to build your site. Defaults to 20. (optional) 37 | # package-manager: pnpm@latest # The Node package manager that should be used to install dependencies and build your site. Automatically detected based on your lockfile. (optional) 38 | 39 | deploy: 40 | needs: build 41 | runs-on: ubuntu-latest 42 | environment: 43 | name: github-pages 44 | url: ${{ steps.deployment.outputs.page_url }} 45 | steps: 46 | - name: Deploy to GitHub Pages 47 | id: deployment 48 | uses: actions/deploy-pages@v4 49 | -------------------------------------------------------------------------------- /.github/workflows/base_benchmarks.yml: -------------------------------------------------------------------------------- 1 | # Based on the tutorial https://bencher.dev/docs/how-to/github-actions/ 2 | 3 | on: 4 | push: 5 | branches: main 6 | paths: 7 | - '**/*.rs' 8 | - 'rust/Cargo.toml' 9 | - 'rust/Cargo.lock' 10 | schedule: 11 | # Run once a month (random values for minute, hour, day; any month or day) 12 | - cron: "3 1 4 * *" 13 | # Allows you to run this workflow manually from the Actions tab on GitHub. 14 | workflow_dispatch: 15 | 16 | jobs: 17 | benchmark_base_branch: 18 | # Run only on the upstream repository. 
19 | if: github.repository == 'google/sedpack' 20 | name: Continuous Benchmarking with Bencher 21 | permissions: 22 | checks: write 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v6 26 | - name: Set up Python 27 | uses: actions/setup-python@v6 28 | with: 29 | cache: 'pip' 30 | - name: Installing sedpack pip package 31 | run: | 32 | pip install sedpack 33 | - name: Prepare benchmarking data 34 | working-directory: ./rust 35 | run: bash benches/setup.sh 36 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1 37 | - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5 38 | - name: Track base branch benchmarks with Bencher 39 | working-directory: ./rust 40 | run: | 41 | bencher run \ 42 | --project sedpack \ 43 | --token '${{ secrets.BENCHER_API_TOKEN }}' \ 44 | --branch main \ 45 | --testbed ubuntu-latest \ 46 | --threshold-measure latency \ 47 | --threshold-test t_test \ 48 | --threshold-max-sample-size 64 \ 49 | --threshold-upper-boundary 0.99 \ 50 | --thresholds-reset \ 51 | --err \ 52 | --adapter rust_criterion \ 53 | --github-actions '${{ secrets.GITHUB_TOKEN }}' \ 54 | "cargo bench" 55 | -------------------------------------------------------------------------------- /tests/io/test_dataset_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | from typing import Union 17 | 18 | import numpy as np 19 | 20 | import sedpack 21 | from sedpack.io import Dataset, Metadata 22 | 23 | 24 | def test_dataset_info_is_copy(tmpdir: Union[str, Path]) -> None: 25 | dtype = "float32" 26 | compression = "LZ4" 27 | tiny_experiment_path: Path = Path(tmpdir) / "dataset_info" 28 | array_of_values = np.random.random((1024, 138)) 29 | array_of_values = array_of_values.astype(dtype) 30 | 31 | # Create a dataset 32 | 33 | dataset_metadata = Metadata(description="Test of the lib") 34 | 35 | example_attributes = [ 36 | sedpack.io.metadata.Attribute( 37 | name="attribute_name", 38 | dtype=str(dtype), 39 | shape=array_of_values[0].shape, 40 | ), 41 | ] 42 | 43 | dataset_structure = sedpack.io.metadata.DatasetStructure( 44 | saved_data_description=example_attributes, 45 | compression=compression, 46 | examples_per_shard=256, 47 | shard_file_type="fb", 48 | ) 49 | 50 | dataset = Dataset.create( 51 | path=tiny_experiment_path, 52 | metadata=dataset_metadata, 53 | dataset_structure=dataset_structure, 54 | ) 55 | 56 | old_dataset_info = dataset.dataset_info 57 | 58 | # Will make a copy 59 | dataset.dataset_info.dataset_structure.compression = "" 60 | 61 | assert old_dataset_info == dataset.dataset_info 62 | -------------------------------------------------------------------------------- /website/astro.config.mjs: -------------------------------------------------------------------------------- 1 | // @ts-check 2 | import { defineConfig } from 'astro/config'; 3 | import starlight from '@astrojs/starlight'; 4 | import remarkMath from 'remark-math'; 5 | import rehypeMathJax from 'rehype-mathjax'; 6 | 7 | // https://astro.build/config 8 | export default defineConfig({ 9 | site: 'https://google.github.io/sedpack/', 10 | base: '/sedpack', 11 | 12 | // Configure `remark-math` and `rehype-mathjax` plugins: 13 | markdown: { 14 | remarkPlugins: [remarkMath], 15 | rehypePlugins: [rehypeMathJax], 16 | }, 17 | 18 | integrations: [ 19 | starlight({ 20 | title: 'Sedpack Documentation', 21 | social: [ 22 | { 23 | icon: 'github', 24 | label: 'GitHub', 25 | href: 'https://github.com/google/sedpack', 26 | } 27 | ], 28 | // Custom CSS to style MathJax equations 29 | customCss: ['./src/mathjax.css'], 30 | sidebar: [ 31 | { 32 | label: 'Start Here', 33 | items: [ 34 | // Each item here is one entry in the navigation menu. 
35 | { label: 'Getting Started', slug: 'start_here/intro' }, 36 | { label: 'Installation', slug: 'start_here/install' }, 37 | ], 38 | }, 39 | { 40 | label: 'Tutorials', 41 | items: [ 42 | { label: 'MNIST', slug: 'tutorials/mnist' }, 43 | { 44 | label: 'Side Channel Attacks', 45 | items: [ 46 | { label: 'SCA Overview', slug: 'tutorials/sca/overview' }, 47 | { label: 'Dataset Preparation', slug: 'tutorials/sca/dataset' }, 48 | { 49 | label: 'Classical Attacks', 50 | items: [ 51 | { label: 'Signal to Noise Ratio', slug: 'tutorials/sca/snr' }, 52 | { label: 'GPU Acceleration of CPA and Template Attacks', slug: 'tutorials/sca/gpu_cpa_template' }, 53 | ], 54 | }, 55 | { label: 'Deep Learning (GPAM)', slug: 'tutorials/sca/gpam' }, 56 | ], 57 | }, 58 | ], 59 | }, 60 | ], 61 | }), 62 | ], 63 | }); 64 | -------------------------------------------------------------------------------- /tests/io/test_file_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | import pytest 17 | 18 | from sedpack.io.file_info import PathGenerator 19 | 20 | 21 | def test_multilevel() -> None: 22 | levels: int = 3 23 | max_branching: int = 2 24 | name_length: int = 3 25 | 26 | generator = PathGenerator( 27 | levels=levels, 28 | max_branching=max_branching, 29 | name_length=name_length, 30 | ) 31 | 32 | seen_paths: set[Path] = set() 33 | 34 | # No exception in the top level 35 | for _ in range((max_branching**levels) + 10): 36 | p = generator.get_path() 37 | assert len(p.parts) == levels 38 | 39 | # Last part is long, previous are at most name_length 40 | assert len(p.parts[-1]) >= 32 41 | assert all(len(part) == name_length for part in p.parts[:-1]) 42 | 43 | seen_paths.add(p) 44 | 45 | # Enforce bounded number of subdirectories except the top 46 | for l in range(levels - 1): 47 | for p in seen_paths: 48 | prefix = p.parents[l] 49 | count = sum(path.is_relative_to(prefix) for path in seen_paths) 50 | assert count <= max_branching**(l + 1) 51 | 52 | 53 | def test_single_level() -> None: 54 | max_branching: int = 3 55 | name_length: int = 3 56 | generator = PathGenerator( 57 | levels=1, 58 | max_branching=max_branching, 59 | name_length=name_length, 60 | ) 61 | 62 | # No exception in the top level 63 | for _ in range(max_branching + 10): 64 | n = generator.get_path() 65 | assert len(str(n)) >= 32 # last level is unbounded length 66 | assert len(n.parts) == 1 67 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | types: [opened, synchronize, reopened] 8 | paths: 9 | - '**/*.py' 10 | - '**/*.rs' 11 | - 'mypy.ini' 12 | merge_group: # Needed for required workflows 13 | # Run after a review has 
been submitted (this is a required workflow which
14 |   # might not be triggered when no code changes -- trigger before going to
15 |   # merge queue).
16 |   pull_request_review:
17 |     types: [submitted]
18 | 
19 | jobs:
20 |   mypy:
21 |     runs-on: ubuntu-22.04
22 |     steps:
23 |       - uses: actions/checkout@v6
24 |       - name: Set up Python 3.10
25 |         uses: actions/setup-python@v6
26 |         with:
27 |           python-version: '3.10'
28 |           cache: 'pip'
29 |       - name: Get pip cache directory
30 |         id: pip-cache
31 |         shell: bash
32 |         run: |
33 |           echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
34 |       - name: Use cached venv or create it
35 |         uses: actions/cache/restore@v5
36 |         id: cache
37 |         with:
38 |           path: ${{ steps.pip-cache.outputs.dir }}
39 |           # The cache key depends on requirements.txt
40 |           key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('pyproject.toml') }}
41 |           restore-keys: |
42 |             ${{ runner.os }}-pip-
43 |       # Build a virtualenv, but only if it doesn't already exist
44 |       - name: Populate pip cache
45 |         run: |
46 |           pip install --require-hashes --no-deps -r requirements.txt
47 |           pip install --editable ".[dev]"
48 |       - name: Save cache
49 |         id: cache-save
50 |         uses: actions/cache/save@v5
51 |         with:
52 |           path: ${{ steps.pip-cache.outputs.dir }}
53 |           key: ${{ steps.cache.outputs.cache-primary-key }}
54 |         if: steps.cache.outputs.cache-hit != 'true'
55 |       - name: Register matcher
56 |         run: echo ::add-matcher::./.github/python_matcher.json
57 |       - name: Running mypy
58 |         run: |
59 |           echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV
60 |           mkdir -p .mypy_cache
61 |           mypy --version
62 |           mypy --no-color-output --install-types --non-interactive src docs
63 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Sedpack - Scalable and efficient data packing
 2 | 
 3 | [![Coverage Status](https://coveralls.io/repos/github/google/sedpack/badge.svg?branch=main)](https://coveralls.io/github/google/sedpack?branch=main)
 4 | 
 5 | [Documentation](https://google.github.io/sedpack/)
 6 | 
 7 | Mainly refactored from the [SCAAML](https://github.com/google/scaaml) project.
 8 | 
 9 | ## Available components
10 | 
11 | See the documentation website:
12 | [https://google.github.io/sedpack/](https://google.github.io/sedpack/).
13 | 
14 | ## Install
15 | 
16 | ### Dependencies
17 | 
18 | To use this library you need to have a working version of [TensorFlow
19 | 2.x](https://www.tensorflow.org/install).
20 | 
21 | Development dependencies:
22 | 
23 | - python-dev and gcc for [xxhash](https://pypi.org/project/xxhash/)
24 | 
25 | ### Dataset install
26 | 
27 | #### Development install
28 | 
29 | 1. Clone the repository: `git clone https://github.com/google/sedpack`
30 | 2. Install dependencies: `python3 -m pip install --require-hashes -r requirements.txt`
31 | 3. Install the package in development mode: `python3 -m pip install --editable
32 |    .` (short `pip install -e .` or legacy `python setup.py develop`)
33 | 
34 | #### Rust install
35 | 
36 | - Activate your Python virtual environment
37 | - [Install Rust](https://www.rust-lang.org/tools/install)
38 | - Run `maturin develop --release`
39 | - Run `python -m pytest` from the project root directory -- no tests should
40 |   be skipped
41 | 
42 | ### Update dependencies
43 | 
44 | Make sure you have run `sudo apt install python3 python3-pip python3-venv` and
45 | activated the virtual environment.
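For example, to create and activate one (a minimal sketch; any environment
name works):

```bash
python3 -m venv my_env        # Create a Python virtual environment
source my_env/bin/activate    # Activate it in the current shell
```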
46 | 
47 | Install requirements: `pip install --require-hashes -r base-tooling-requirements.txt`
48 | 
49 | Update: `pip-compile pyproject.toml --generate-hashes --upgrade` and commit requirements.txt.
50 | 
51 | ### Package install
52 | 
53 | `pip install sedpack`
54 | 
55 | ### Tutorial
56 | 
57 | A tutorial and documentation are available at
58 | [https://google.github.io/sedpack/](https://google.github.io/sedpack/).
59 | 
60 | Code for the tutorials is available in the `docs/tutorials` directory. For a
61 | "hello world" see
62 | [https://google.github.io/sedpack/tutorials/mnist/](https://google.github.io/sedpack/tutorials/mnist/).
63 | 
64 | ## Disclaimer
65 | 
66 | This is not an official Google product.
67 | 
--------------------------------------------------------------------------------
/tests/io/iteration/test_rust_generator.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2025 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from pathlib import Path
16 | import pytest
17 | from typing import Union
18 | 
19 | import numpy as np
20 | 
21 | import sedpack
22 | from sedpack.io.iteration import RustGenerator
23 | from sedpack.io.metadata import DatasetStructure
24 | 
25 | 
26 | def test_wrong_file_parallelism() -> None:
27 |     with pytest.raises(
28 |             ValueError,
29 |             match="The argument file_parallelism should be positive.*",
30 |     ):
31 |         g = RustGenerator(
32 |             dataset_path=Path(),
33 |             dataset_structure=DatasetStructure(),
34 |             shard_iterator=[],
35 |             process_record=None,
36 |             file_parallelism=0,
37 |         )
38 | 
39 | 
40 | def test_wrong_shard_type() -> None:
41 |     with pytest.raises(
42 |             ValueError,
43 |             match="RustGenerator is implemented only for FlatBuffers.",
44 |     ):
45 |         g = RustGenerator(
46 |             dataset_path=Path(),
47 |             dataset_structure=DatasetStructure(shard_file_type="tfrec"),
48 |             shard_iterator=[],
49 |             process_record=None,
50 |             file_parallelism=1,
51 |         )
52 | 
53 | 
54 | def test_wrong_compression() -> None:
55 |     with pytest.raises(
56 |             ValueError,
57 |             match=
58 |             "The compression .* is not among the supported compressions: .*",
59 |     ):
60 |         g = RustGenerator(
61 |             dataset_path=Path(),
62 |             dataset_structure=DatasetStructure(
63 |                 shard_file_type="fb",
64 |                 compression="ZIP",
65 |             ),
66 |             shard_iterator=[],
67 |             process_record=None,
68 |             file_parallelism=1,
69 |         )
70 | 
--------------------------------------------------------------------------------
/.github/workflows/fork_pr_benchmarks_track.yml:
--------------------------------------------------------------------------------
 1 | # Based on https://bencher.dev/docs/how-to/github-actions/
 2 | name: Track Benchmarks with Bencher
 3 | permissions:
 4 |   contents: read
 5 |   pull-requests: write
 6 | 
 7 | on:
 8 |   workflow_run:
 9 |     workflows: [Run Benchmarks]
10 |     types: [completed]
11 | 
12 | jobs:
13 |   track_fork_pr_branch:
14 |     if: github.event.workflow_run.conclusion == 'success'
15 |     runs-on: ubuntu-latest
16 |     env:
17 |       BENCHMARK_RESULTS:
benchmark_results.txt 18 | PR_EVENT: event.json 19 | steps: 20 | - name: Download Benchmark Results 21 | uses: dawidd6/action-download-artifact@ac66b43f0e6a346234dd65d4d0c8fbb31cb316e5 # v11 22 | with: 23 | name: ${{ env.BENCHMARK_RESULTS }} 24 | run_id: ${{ github.event.workflow_run.id }} 25 | - name: Download PR Event 26 | uses: dawidd6/action-download-artifact@ac66b43f0e6a346234dd65d4d0c8fbb31cb316e5 # v11 27 | with: 28 | name: ${{ env.PR_EVENT }} 29 | run_id: ${{ github.event.workflow_run.id }} 30 | - name: Figure out what is where 31 | run: ls 32 | - name: Export PR Event Data 33 | uses: actions/github-script@v8 34 | with: 35 | script: | 36 | let fs = require('fs'); 37 | let prEvent = JSON.parse(fs.readFileSync(process.env.PR_EVENT, {encoding: 'utf8'})); 38 | core.exportVariable("PR_HEAD", prEvent.pull_request.head.ref); 39 | core.exportVariable("PR_BASE", prEvent.pull_request.base.ref); 40 | core.exportVariable("PR_BASE_SHA", prEvent.pull_request.base.sha); 41 | core.exportVariable("PR_NUMBER", prEvent.number); 42 | - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5 43 | - name: Track Benchmarks with Bencher 44 | run: | 45 | bencher run \ 46 | --project sedpack \ 47 | --token '${{ secrets.BENCHER_API_TOKEN }}' \ 48 | --branch "$PR_HEAD" \ 49 | --start-point "$PR_BASE" \ 50 | --start-point-hash "$PR_BASE_SHA" \ 51 | --start-point-clone-thresholds \ 52 | --start-point-reset \ 53 | --testbed ubuntu-latest \ 54 | --err \ 55 | --adapter rust_criterion \ 56 | --github-actions '${{ secrets.GITHUB_TOKEN }}' \ 57 | --ci-number "$PR_NUMBER" \ 58 | --file "$BENCHMARK_RESULTS" 59 | -------------------------------------------------------------------------------- /tests/io/npz/test_npz_shards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | 19 | from sedpack.io.metadata import Attribute, DatasetStructure 20 | from sedpack.io.shard.shard_writer_np import ShardWriterNP 21 | from sedpack.io.npz import IterateShardNP 22 | 23 | 24 | def test_attribute_bytes(tmpdir: str | Path) -> None: 25 | shard_file = Path(tmpdir / "test_shard.npz") 26 | examples_per_shard = 4 27 | attr_1_shape = (128,) 28 | attr_2_shape = (8,) 29 | attr_1_values = np.random.uniform(size=(examples_per_shard, *attr_1_shape)) 30 | attr_2_values = np.random.uniform(size=(examples_per_shard, *attr_2_shape)) 31 | 32 | dataset_structure = DatasetStructure( 33 | saved_data_description=[ 34 | Attribute( 35 | name="attr_1", 36 | dtype="float32", 37 | shape=attr_1_shape, 38 | ), 39 | Attribute( 40 | name="attr_2", 41 | dtype="float32", 42 | shape=attr_2_shape, 43 | ), 44 | ], 45 | compression="ZIP", 46 | examples_per_shard=examples_per_shard, 47 | shard_file_type="npz", 48 | ) 49 | 50 | shard_writer_np = ShardWriterNP( 51 | dataset_structure=dataset_structure, 52 | shard_file=shard_file, 53 | ) 54 | 55 | for i in range(examples_per_shard): 56 | shard_writer_np.write(values={ 57 | "attr_1": attr_1_values[i], 58 | "attr_2": attr_2_values[i], 59 | }) 60 | shard_writer_np.close() 61 | 62 | for i, e in enumerate( 63 | IterateShardNP(dataset_structure=dataset_structure, 64 | process_record=None).iterate_shard(shard_file)): 65 | assert np.allclose(e["attr_1"], attr_1_values[i]) 66 | assert np.allclose(e["attr_2"], attr_2_values[i]) 67 | -------------------------------------------------------------------------------- /tests/io/test_hash_checksums.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
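# `hash_checksums` digests a file on disk while `hash_checksums_from_bytes`
# digests an in-memory buffer; the tests pin known digest values and check
# that both code paths agree.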
14 | 
15 | from pathlib import Path
16 | 
17 | import numpy as np
18 | import pytest
19 | 
20 | from sedpack.io.utils import hash_checksums, hash_checksums_from_bytes
21 | 
22 | 
23 | def test_hash_checksums_known_values(tmpdir: str | Path) -> None:
24 |     file_name = tmpdir / "file.txt"
25 |     hashes = ("md5", "sha256", "sha512", "sha384")
26 |     with open(file_name, "w", encoding="ascii") as f:
27 |         f.write("Hello world")
28 | 
29 |     checksums = hash_checksums(file_name, hashes)
30 | 
31 |     assert checksums == (
32 |         "3e25960a79dbc69b674cd4ec67a72c62",
33 |         "64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c",
34 |         "b7f783baed8297f0db917462184ff4f08e69c2d5e5f79a942600f9725f58ce1f29c18139bf80b06c0fff2bdd34738452ecf40c488c22a7e3d80cdf6f9c1c0d47",
35 |         "9203b0c4439fd1e6ae5878866337b7c532acd6d9260150c80318e8ab8c27ce330189f8df94fb890df1d298ff360627e1",
36 |     )
37 | 
38 | 
39 | @pytest.mark.parametrize("size", [1, 3, 7, 8, 13, 17, 33, 67])
40 | def test_equivalent(size: int, tmp_path: Path) -> None:
41 |     file_content = np.random.randint(
42 |         0,
43 |         256,
44 |         size=size,
45 |         dtype=np.uint8,
46 |     ).tobytes()
47 |     hashes = (
48 |         "md5",
49 |         "sha256",
50 |         "sha512",
51 |         "sha384",
52 |         "xxh32",
53 |         "xxh64",
54 |         "xxh128",
55 |     )
56 | 
57 |     tmp_path = tmp_path / "shard_file.extension"
58 | 
59 |     with open(tmp_path, "wb") as f:
60 |         f.write(file_content)
61 | 
62 |     # Redundant, but read again.
63 |     with open(tmp_path, "rb") as f:
64 |         assert f.read() == file_content
65 | 
66 |     assert hash_checksums_from_bytes(
67 |         file_content=file_content,
68 |         hashes=hashes,
69 |     ) == hash_checksums(
70 |         file_path=tmp_path,
71 |         hashes=hashes,
72 |     )
73 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # package specific
 2 | *.zip
 3 | .vscode/
 4 | # Byte-compiled / optimized / DLL files
 5 | __pycache__/
 6 | *.py[cod]
 7 | *$py.class
 8 | *.prof
 9 | # C extensions
10 | *.so
11 | tmp/
12 | algo_arch_implem_v1_train/
13 | 
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 | 
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 | 
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 | 
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | 
59 | # Translations
60 | *.mo
61 | *.pot
62 | 
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 | 
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 | 
73 | # Scrapy stuff:
74 | .scrapy
75 | 
76 | # Sphinx documentation
77 | docs/_build/
78 | 
79 | # PyBuilder
80 | target/
81 | 
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 | 
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 | 
89 | # pyenv
90 | .python-version
91 | 
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 | 
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 | 
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 | 
106 | # SageMath parsed files
107 | *.sage.py
108 | 
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 | 
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 | 
122 | # Rope project settings
123 | .ropeproject
124 | 
125 | # mkdocs documentation
126 | /site
127 | 
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 | 
133 | # Pyre type checker
134 | .pyre/
135 | 
--------------------------------------------------------------------------------
/src/sedpack/io/types.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2023-2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Type aliases used throughout sedpack I/O."""
15 | 
16 | # pylint: disable=invalid-name
17 | 
18 | from typing import Any, Literal, Union
19 | 
20 | import numpy as np
21 | import numpy.typing as npt
22 | from typing_extensions import TypeAlias  # from typing since 3.10
23 | 
24 | # Valid split values (used also as directory names).
25 | SplitT: TypeAlias = Literal["train", "test", "holdout"]
26 | TRAIN_SPLIT: SplitT = "train"
27 | TEST_SPLIT: SplitT = "test"
28 | HOLDOUT_SPLIT: SplitT = "holdout"
29 | 
30 | # Keras model.
31 | TFModelT: TypeAlias = Any
32 | 
33 | # tf.data.Dataset and similar.
34 | TFDatasetT: TypeAlias = Any
35 | 
36 | # Type of an attribute value.
37 | AttributeValueT: TypeAlias = Union[
38 |     str,  # UTF-8 string
39 |     int,
40 |     bytes,
41 |     npt.NDArray[np.generic],
42 | ]
43 | 
44 | # Type of a batch of attribute values.
45 | BatchedAttributeValueT: TypeAlias = Union[
46 |     list[str],  # UTF-8 string
47 |     list[int],
48 |     list[bytes],
49 |     # NP has the first dimension as the batch dimension.
50 |     npt.NDArray[np.generic],
51 | ]
52 | 
53 | # Compression choices.
54 | CompressionT: TypeAlias = Literal[
55 |     "",
56 |     "BZ2",
57 |     "GZIP",
58 |     "LZMA",
59 |     "LZ4",
60 |     "ZIP",
61 |     "ZLIB",
62 |     "ZSTD",
63 | ]
64 | 
65 | # Hash checksum types (algorithms supported by hashlib or xxhash).
66 | HashChecksumT: TypeAlias = Literal[
67 |     "md5",
68 |     "sha1",
69 |     "sha224",
70 |     "sha256",
71 |     "sha384",
72 |     "sha512",
73 |     "sha3_224",
74 |     "sha3_256",
75 |     "sha3_384",
76 |     "sha3_512",
77 |     "xxh32",
78 |     "xxh64",
79 |     "xxh128",
80 | ]
81 | 
82 | # Shard file-type choices. Also serves as the file-extension of shard files.
83 | ShardFileTypeT: TypeAlias = Literal[
84 |     "fb",
85 |     "npz",
86 |     "tfrec",
87 | ]
88 | 
89 | # Type alias for a single example; this is what gets iterated or saved.
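# A hypothetical instance (attribute names are illustrative only):
#   {"trace": np.zeros(100, dtype=np.float32), "key": b"\x00" * 16}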
90 | ExampleT: TypeAlias = dict[str, AttributeValueT]
91 | 
92 | BatchT: TypeAlias = dict[str, BatchedAttributeValueT]
93 | 
--------------------------------------------------------------------------------
/website/src/content/docs/start_here/install.mdx:
--------------------------------------------------------------------------------
 1 | ---
 2 | title: Sedpack Installation
 3 | description: Sedpack Installation
 4 | ---
 5 | 
 6 | import { Tabs, TabItem } from '@astrojs/starlight/components';
 7 | 
 8 | Installation of the Sedpack Python package.
 9 | 
10 | ## Python Package Index
11 | 
12 | All one should need to install the package from
13 | [PyPI](https://pypi.org/project/sedpack/) is:
14 | 
15 | ```bash
16 | pip install sedpack
17 | ```
18 | 
19 | Note that this is the latest stable version of the package.
20 | If you want the bleeding-edge features, install from source instead.
21 | 
22 | ## Installing from Source
23 | 
24 | One can always opt for installation from source.
25 | 
26 | Development dependencies:
27 | 
28 | - python-dev and gcc for [xxhash](https://pypi.org/project/xxhash/)
29 | - [Install Rust](https://www.rust-lang.org/tools/install)
30 | 
31 | Clone and create a Python virtual environment:
32 | 
33 | ```bash
34 | git clone github.com/google/sedpack/  # Clone the repository
35 | python3 -m venv my_env                # Create Python virtual environment
36 | source my_env/bin/activate            # Activate your virtual environment
37 | cd sedpack/                           # Change directory to the cloned git repository
38 | ```
39 | 
40 | [optional] Install pinned Python dependencies:
41 | 
42 | <Tabs>
43 | <TabItem label="Linux">
44 | 
45 | ```bash
46 | python3 -m pip install --require-hashes -r requirements.txt
47 | ```
48 | 
49 | </TabItem>
50 | <TabItem label="Windows">
51 | Currently the file `requirements.txt` is compiled for Linux and some
52 | versions of PyPI packages might differ or might not even exist
53 | elsewhere. One can always recreate those:
54 | 
55 | ```bash
56 | python3 -m pip install pip-tools
57 | pip-compile --generate-hashes pyproject.toml > requirements_win.txt
58 | python3 -m pip install --require-hashes -r requirements_win.txt
59 | ```
60 | 
61 | </TabItem>
62 | <TabItem label="macOS">
63 | Currently the file `requirements.txt` is compiled for Linux and some
64 | versions of PyPI packages might differ or might not even exist
65 | elsewhere. One can always recreate those:
66 | 
67 | ```bash
68 | python3 -m pip install pip-tools
69 | pip-compile --generate-hashes pyproject.toml > requirements_mac.txt
70 | python3 -m pip install --require-hashes -r requirements_mac.txt
71 | ```
72 | 
73 | </TabItem>
74 | </Tabs>
75 | Install sedpack with development dependencies:
76 | 
77 | ```bash
78 | python3 -m pip install --editable '.[dev]'
79 | ```
80 | 
81 | [optional] Run unit-tests, none should be skipped:
82 | 
83 | ```bash
84 | python -m pytest
85 | ```
86 | 
87 | [optional] Rebuild the Rust library after a change:
88 | 
89 | ```bash
90 | maturin develop --release
91 | ```
92 | 
--------------------------------------------------------------------------------
/src/sedpack/io/shard/iterate_shard_base.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for shard iteration."""
15 | 
16 | from abc import ABC, abstractmethod
17 | from pathlib import Path
18 | from typing import AsyncIterator, Callable, Generic, Iterable, TypeVar
19 | 
20 | from sedpack.io.metadata import DatasetStructure
21 | from sedpack.io.types import ExampleT
22 | 
23 | T = TypeVar("T")
24 | 
25 | 
26 | class IterateShardBase(ABC, Generic[T]):
27 |     """Remember everything needed to be able to iterate shards. This can be
28 |     pickled and passed as a callable object into another process.
29 |     """
30 | 
31 |     def __init__(self, dataset_structure: DatasetStructure,
32 |                  process_record: Callable[[ExampleT], T] | None) -> None:
33 |         """Initialize the shard iterator.
34 | 
35 |         Args:
36 | 
37 |           dataset_structure (DatasetStructure): The structure of the dataset
38 |           allowing shard parsing.
39 | 
40 |           process_record (Callable[[ExampleT], T] | None): How to process each
41 |           record. Needs to be pickleable (for multiprocessing).
42 |         """
43 |         self.dataset_structure: DatasetStructure = dataset_structure
44 |         self.process_record: Callable[[ExampleT], T] | None = process_record
45 | 
46 |     @abstractmethod
47 |     def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
48 |         """Iterate a shard.
49 |         """
50 | 
51 |     @abstractmethod
52 |     def iterate_shard_async(self, file_path: Path) -> AsyncIterator[ExampleT]:
53 |         """Asynchronously iterate a shard.
54 |         """
55 |         # TODO(issue #85) fix and test async iterator typing
56 | 
57 |     @abstractmethod
58 |     def process_and_list(self, shard_file: Path) -> list[T]:
59 |         """Return a list of processed examples. Used as a function call in a
60 |         different process. Returning a list rather than an iterator lets all
61 |         of the work happen in the other process, so that only a memory copy
62 |         between processes remains.
63 | 
64 |         Args:
65 | 
66 |           shard_file (Path): Path to the shard file.
67 | 
68 |         Returns: A list of examples present in the shard. If `process_record`
69 |         is defined it is applied to all those examples.
70 |         """
71 | 
--------------------------------------------------------------------------------
/tests/io/test_bytes.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
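# Round-trip test: variable-length `bytes` attributes (whose declared shape
# is ignored) should be read back from a tfrec-backed dataset byte-for-byte.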
14 | 15 | from pathlib import Path 16 | import pytest 17 | from typing import Union 18 | 19 | import numpy as np 20 | 21 | import sedpack 22 | from sedpack.io import Dataset, Metadata 23 | from sedpack.io.types import TRAIN_SPLIT 24 | from sedpack.io.utils import is_module_present 25 | 26 | 27 | @pytest.mark.skipif( 28 | not is_module_present("tensorflow"), 29 | reason="TensorFlow is optional, skip test if not present.", 30 | ) 31 | def test_attribute_bytes(tmpdir: Union[str, Path]) -> None: 32 | array_of_values = [ 33 | bytes( 34 | np.random.randint( 35 | 0, # low 36 | 256, # high 37 | np.random.randint(10, 1_000, (), np.int32), # size 38 | np.uint8, 39 | )) for _ in range(138) 40 | ] 41 | 42 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment" 43 | 44 | # Create a dataset 45 | 46 | dataset_metadata = Metadata(description="Test of the lib") 47 | 48 | example_attributes = [ 49 | sedpack.io.metadata.Attribute( 50 | name="attribute_name", 51 | dtype="bytes", 52 | shape=(), # ignored 53 | ), 54 | ] 55 | 56 | dataset_structure = sedpack.io.metadata.DatasetStructure( 57 | saved_data_description=example_attributes, 58 | compression="GZIP", 59 | examples_per_shard=256, 60 | shard_file_type="tfrec", 61 | ) 62 | 63 | dataset = Dataset.create( 64 | path=tiny_experiment_path, 65 | metadata=dataset_metadata, 66 | dataset_structure=dataset_structure, 67 | ) 68 | 69 | # Fill data in the dataset 70 | 71 | with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: 72 | for attribute_value in array_of_values: 73 | filler.write_example( 74 | values={"attribute_name": attribute_value}, 75 | split=TRAIN_SPLIT, 76 | ) 77 | 78 | # Check the data is correct 79 | 80 | for i, example in enumerate( 81 | dataset.as_numpy_iterator( 82 | split=TRAIN_SPLIT, 83 | shuffle=0, 84 | repeat=False, 85 | )): 86 | if example["attribute_name"] != array_of_values[i]: 87 | print(example["attribute_name"]) 88 | print(array_of_values[i]) 89 | raise ValueError 90 | 91 | # We tested everything 92 | assert i + 1 == len(array_of_values) 93 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options:
 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
 5 | 
 6 | version: 2
 7 | updates:
 8 |   - package-ecosystem: "npm"
 9 |     directory: "/website"
10 |     labels:
11 |       - "dependencies"
12 |       - "javascript"
13 |     # Run monthly
14 |     schedule:
15 |       interval: "monthly"
16 |       timezone: "Europe/Zurich"
17 |     # Group PRs to avoid having to rebase/merge too many
18 |     groups:
19 |       dependabot:
20 |         patterns:
21 |           - "*"
22 |     # Only care about our direct dependencies
23 |     allow:
24 |       - dependency-type: "direct"
25 |     ignore:
26 |       # Filter out semver patch updates to reduce the frequency of updates
27 |       - dependency-name: "*"
28 |         update-types: ["version-update:semver-patch"]
29 | 
30 |   - package-ecosystem: "github-actions"
31 |     directory: "/"
32 |     labels:
33 |       - "dependencies"
34 |       - "CI"
35 |     # Run every Monday
36 |     schedule:
37 |       interval: "weekly"
38 |       timezone: "Europe/Zurich"
39 |     groups:
40 |       dependabot:
41 |         patterns:
42 |           - "*"
43 |     ignore:
44 |       - dependency-name: "*"
45 |         # For github-actions, we only care about major version updates
46 |         update-types:
47 |           - "version-update:semver-patch"
48 |           - "version-update:semver-minor"
49 | 
50 |   - package-ecosystem: "pip"
51 |     directory: "/"
52 |     labels:
53 |       - "dependencies"
54 |       - "python"
55 |     # Run every Monday
56 |     schedule:
57 |       interval: "weekly"
58 |       timezone: "Europe/Zurich"
59 |     # Group PRs to avoid having to rebase/merge too many
60 |     groups:
61 |       dependabot:
62 |         patterns:
63 |           - "*"
64 |     # Only care about our direct dependencies
65 |     allow:
66 |       - dependency-type: "direct"
67 |     ignore:
68 |       # Filter out semver patch updates to reduce the frequency of updates
69 |       - dependency-name: "*"
70 |         update-types: ["version-update:semver-patch"]
71 | 
72 |   - package-ecosystem: "cargo"
73 |     directory: "/rust/"
74 |     labels:
75 |       - "dependencies"
76 |       - "rust"
77 |     # Run every Monday
78 |     schedule:
79 |       interval: "weekly"
80 |       timezone: "Europe/Zurich"
81 |     # Group PRs to avoid having to rebase/merge too many
82 |     groups:
83 |       dependabot:
84 |         patterns:
85 |           - "*"
86 |     # Only care about our direct dependencies
87 |     allow:
88 |       - dependency-type: "direct"
89 |     ignore:
90 |       # Filter out semver patch updates to reduce the frequency of updates
91 |       - dependency-name: "*"
92 |         update-types: ["version-update:semver-patch"]
93 |       # Peer dependencies of Cargo, not to bump independently:
94 |       - dependency-name: "semver"
95 |       - dependency-name: "crates-io"
96 | 
--------------------------------------------------------------------------------
/tests/io/itertools/test_lazy_pool.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
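# LazyPool provides an imap_unordered in the spirit of multiprocessing.Pool
# but consumes its input iterable lazily. The tests below exercise breaking
# out of the loop early (which must not hang), reusing the pool, and the
# guard against re-entering imap_unordered before the previous iteration
# finished.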
14 | 
15 | import pytest
16 | 
17 | from sedpack.io.itertools import LazyPool
18 | 
19 | 
20 | def test_doc() -> None:
21 | 
22 |     def f(x: int) -> int:
23 |         return 2 * x
24 | 
25 |     with LazyPool(5) as pool:
26 |         result = set(pool.imap_unordered(f, range(10)))
27 | 
28 |     assert result == set(f(i) for i in range(10))
29 | 
30 | 
31 | def test_break_from_loop() -> None:
32 |     f = lambda x: list(range(x))
33 | 
34 |     with LazyPool(5) as pool:
35 |         for l in pool.imap_unordered(f, range(1_000**5)):
36 |             if len(l) > 50:
37 |                 break
38 |     # Not hanging.
39 | 
40 | 
41 | def test_two_functions() -> None:
42 |     with LazyPool(3) as pool:
43 | 
44 |         def f(x: int) -> tuple[int, int]:
45 |             return (x, x)
46 | 
47 |         result = set(pool.imap_unordered(f, range(10)))
48 |         print(result)
49 |         assert result == set((i, i) for i in range(10))
50 | 
51 |         def g(x: int) -> str:
52 |             return f"{x} is a number"
53 | 
54 |         result: set[str] = set(  # type: ignore[no-redef]
55 |             pool.imap_unordered(g, range(10)))
56 |         assert result == set(
57 |             f"{i} is a number"
58 |             for i in range(10))  # type: ignore[comparison-overlap]
59 | 
60 | 
61 | def test_break_from_loop_break_and_iterate_again() -> None:
62 |     f = lambda x: list(range(x))
63 | 
64 |     pool = LazyPool(5)
65 | 
66 |     with pool as pool:
67 |         seen: int = 0
68 |         for l in pool.imap_unordered(f, range(1_000**5)):
69 |             seen += 1
70 |             if len(l) >= 50:
71 |                 break
72 |         assert seen == 51
73 | 
74 |     with pool as pool:
75 |         seen = 0
76 |         for l in pool.imap_unordered(f, range(1_000**5)):
77 |             seen += 1
78 |             if len(l) >= 50:
79 |                 break
80 |         assert seen == 51
81 | 
82 | 
83 | def test_no_restart_while_imap() -> None:
84 |     f = lambda x: list(range(x))
85 | 
86 |     pool = LazyPool(5)
87 | 
88 |     with pool as pool:
89 |         seen: int = 0
90 |         for l in pool.imap_unordered(f, range(1_000**5)):
91 |             seen += 1
92 |             if len(l) >= 50:
93 |                 break
94 |         assert seen == 51
95 | 
96 |     # Neither the previous iteration nor the context manager finished.
97 |     with pytest.raises(AssertionError):
98 |         for l in pool.imap_unordered(f, range(10)):
99 |             pass
100 | 
--------------------------------------------------------------------------------
/src/sedpack/io/shard/shard.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2023-2025 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset shard manipulation.
15 | """
16 | 
17 | from pathlib import Path
18 | 
19 | from sedpack.io.metadata import DatasetStructure
20 | from sedpack.io.shard_file_metadata import ShardInfo
21 | from sedpack.io.types import ExampleT
22 | from sedpack.io.shard.shard_writer_base import ShardWriterBase
23 | from sedpack.io.shard.get_shard_writer import get_shard_writer
24 | 
25 | 
26 | class Shard():
27 |     """A shard contains N measurements pertaining to the same key."""
28 | 
29 |     def __init__(self, shard_info: ShardInfo,
30 |                  dataset_structure: DatasetStructure,
31 |                  dataset_root_path: Path) -> None:
32 |         """Collect information about a new shard.
33 | 
34 |         Args:
35 | 
36 |           shard_info (ShardInfo): Information about this shard.
37 | 
38 |           dataset_structure (DatasetStructure): The structure of data being
39 |           saved.
40 | 
41 |           dataset_root_path (Path): Path to the dataset.
42 |         """
43 |         # Information needed to save the shard.
44 |         self.shard_info: ShardInfo = shard_info
45 |         self.dataset_structure: DatasetStructure = dataset_structure
46 |         self._dataset_path: Path = dataset_root_path
47 | 
48 |         self._shard_writer: ShardWriterBase | None = get_shard_writer(
49 |             dataset_structure=dataset_structure,
50 |             shard_file=self._get_full_path(),
51 |         )
52 | 
53 |     def write(self, values: ExampleT) -> None:
54 |         """Write an example to disk (the format is given by the dataset structure).
55 | 
56 |         Args:
57 | 
58 |           values (ExampleT): Attribute values.
59 |         """
60 |         if not self._shard_writer:
61 |             raise ValueError("Attempting to write to a closed shard.")
62 | 
63 |         self._shard_writer.write(values)
64 |         self.shard_info.number_of_examples += 1
65 | 
66 |     def close(self) -> ShardInfo:
67 |         """Close shard and return statistics.
68 |         """
69 |         if self._shard_writer is None:
70 |             raise ValueError("Closing a shard which has not been opened.")
71 | 
72 |         hash_checksums: tuple[str, ...] = self._shard_writer.close()
73 |         self._shard_writer = None
74 | 
75 |         # Store the hash checksums of the shard file.
76 |         self.shard_info.file_infos[0].hash_checksums = hash_checksums
77 | 
78 |         # Return shard info.
79 |         return self.shard_info
80 | 
81 |     def _get_full_path(self) -> Path:
82 |         """Return full path to the shard file.
83 |         """
84 |         return self._dataset_path / self.shard_info.file_infos[0].file_path
85 | 
--------------------------------------------------------------------------------
/rust/benches/my_benchmark.rs:
--------------------------------------------------------------------------------
 1 | // Copyright 2025 Google LLC
 2 | //
 3 | // Licensed under the Apache License, Version 2.0 (the "License");
 4 | // you may not use this file except in compliance with the License.
 5 | // You may obtain a copy of the License at
 6 | //
 7 | //     https://www.apache.org/licenses/LICENSE-2.0
 8 | //
 9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | use criterion::{Criterion, criterion_group, criterion_main};
16 | use glob::glob;
17 | use sedpack_rs::batch_iteration::BatchIterator;
18 | use sedpack_rs::example_iteration::{
19 |     CompressionType, ExampleIterator, ShardInfo, get_shard_progress,
20 | };
21 | pub use sedpack_rs::parallel_map::parallel_map;
22 | 
23 | pub fn get_shard_files() -> Vec<ShardInfo> {
24 |     let shard_infos: Vec<_> = glob("mnist_fb_gzip/**/*.fb")
25 |         .expect("Failed to load dataset")
26 |         .filter_map(|p| p.ok())
27 |         .map(|p| p.display().to_string())
28 |         .map(|file_path| ShardInfo { file_path, compression_type: CompressionType::Gzip })
29 |         .collect();
30 |     println!(">> Decoding {} shards", shard_infos.len());
31 |     assert_eq!(shard_infos.len(), 275);
32 |     shard_infos
33 | }
34 | 
35 | pub fn batch_iterator_benchmark_deterministic(c: &mut Criterion) {
36 |     let shard_infos = get_shard_files();
37 |     c.bench_function("BatchIteratorDeterministic", |b| {
38 |         b.iter(|| {
39 |             for batch in BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true], 0) {
40 |                 let _ = std::hint::black_box(batch);
41 |             }
42 |         })
43 |     });
44 | }
45 | 
46 | pub fn batch_iterator_benchmark_shuffled(c: &mut Criterion) {
47 |     let shard_infos = get_shard_files();
48 |     c.bench_function("BatchIteratorShuffled", |b| {
49 |         b.iter(|| {
50 |             for batch in BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true], 256) {
51 |                 let _ = std::hint::black_box(batch);
52 |             }
53 |         })
54 |     });
55 | }
56 | 
57 | pub fn example_iterator_benchmark(c: &mut Criterion) {
58 |     let shard_infos = get_shard_files();
59 |     c.bench_function("ExampleIterator", |b| {
60 |         b.iter(|| {
61 |             for example in ExampleIterator::new(shard_infos.clone(), 12) {
62 |                 let _ = std::hint::black_box(example);
63 |             }
64 |         })
65 |     });
66 | }
67 | 
68 | pub fn parallel_map_benchmark(c: &mut Criterion) {
69 |     let shard_infos = get_shard_files();
70 |     c.bench_function("parallel_map", |b| {
71 |         b.iter(|| {
72 |             for shard in
73 |                 parallel_map(|x| get_shard_progress(&x), shard_infos.clone().into_iter(), 32)
74 |             {
75 |                 let _ = std::hint::black_box(shard);
76 |             }
77 |         })
78 |     });
79 | }
80 | 
81 | criterion_group!(
82 |     benches,
83 |     batch_iterator_benchmark_deterministic,
84 |     batch_iterator_benchmark_shuffled,
85 |     example_iterator_benchmark,
86 |     parallel_map_benchmark,
87 | );
88 | criterion_main!(benches);
89 | 
--------------------------------------------------------------------------------
/.github/workflows/fork_pr_benchmarks_run.yml:
--------------------------------------------------------------------------------
 1 | # Based on https://bencher.dev/docs/how-to/github-actions/
 2 | name: Run Benchmarks
 3 | permissions:
 4 |   contents: read
 5 |   pull-requests: write
 6 | 
 7 | on:
 8 |   pull_request:
 9 |     types: [opened, edited]
10 |     paths:
11 |       - '**/*.rs'
12 |       - 'rust/Cargo.toml'
13 |       - 'rust/Cargo.lock'
14 | 
15 | jobs:
16 |   benchmark_fork_pr_branch:
17 |     # Run only on the upstream repository.
18 |     if: github.repository == 'google/sedpack'
19 |     name: Run Fork PR Benchmarks
20 |     runs-on: ubuntu-latest
21 |     steps:
22 |       - uses: actions/checkout@v6
23 | 
24 |       # begin: Use source sedpack package
25 |       - uses: actions/checkout@v6
26 |       - name: Set up Python 3.10
27 |         uses: actions/setup-python@v6
28 |         with:
29 |           python-version: '3.10'
30 |           cache: 'pip'
31 |       - name: Get pip cache directory
32 |         id: pip-cache
33 |         shell: bash
34 |         run: |
35 |           echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
36 |       - name: Use cached venv or create it
37 |         uses: actions/cache/restore@v5
38 |         id: cache
39 |         with:
40 |           path: ${{ steps.pip-cache.outputs.dir }}
41 |           # The cache key depends on requirements.txt
42 |           key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
43 |       # Build a virtualenv, but only if it doesn't already exist
44 |       - name: Populate pip cache
45 |         # requirements.txt is not reliable since across different platforms and
46 |         # their versions the pip package versions might vary. We regenerate it
47 |         # again from pyproject.toml every time when pyproject.toml or
48 |         # requirements.txt changes. The pinned versions in requirements.txt are
49 |         # tested by coverage since that is running on ubuntu which is also used
50 |         # to produce the main requirements.txt file.
51 |         run: |
52 |           pip install pip-tools
53 |           pip-compile --generate-hashes --extra dev pyproject.toml > requirements.txt
54 |           pip install -r requirements.txt
55 |         if: steps.cache.outputs.cache-hit != 'true'
56 |       - name: Save cache
57 |         id: cache-save
58 |         uses: actions/cache/save@v5
59 |         with:
60 |           path: ${{ steps.pip-cache.outputs.dir }}
61 |           key: ${{ steps.cache.outputs.cache-primary-key }}
62 |         if: steps.cache.outputs.cache-hit != 'true'
63 |       - name: Install sedpack locally
64 |         run: pip install --editable .
65 |       # end: Use source sedpack package
66 | 
67 |       - name: Prepare benchmarking data
68 |         run: (cd rust/ ; bash benches/setup.sh )
69 | 
70 |       - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1
71 |       - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5
72 |       - name: Benchmarking
73 |         run: (cd rust/ ; cargo bench > benchmark_results.txt)
74 |       - name: Upload Benchmark Results
75 |         uses: actions/upload-artifact@v6
76 |         with:
77 |           name: benchmark_results.txt
78 |           path: ./rust/benchmark_results.txt
79 |       - name: Upload GitHub Pull Request Event
80 |         uses: actions/upload-artifact@v6
81 |         with:
82 |           name: event.json
83 |           path: ${{ github.event_path }}
84 | 
--------------------------------------------------------------------------------
/tests/io/test_compression.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
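# Round-trip tests for CompressedFile: decompress(compress(x)) == x for each
# supported algorithm, plus interoperability of the GZIP setting with files
# written and read by the standard gzip module.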
14 | 15 | import gzip 16 | from pathlib import Path 17 | from typing import Union 18 | 19 | import pytest 20 | 21 | from sedpack.io.compress import CompressedFile 22 | from sedpack.io.types import CompressionT 23 | 24 | 25 | def test_compress_gzip_write(tmpdir: str | Path) -> None: 26 | file_name = tmpdir / "compressed_file" 27 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 28 | 29 | # Can read what gzip writes 30 | with gzip.open(file_name, "wb") as f: 31 | f.write(payload) 32 | 33 | with open(file_name, "rb") as f: 34 | assert CompressedFile("GZIP").decompress(f.read()) == payload 35 | 36 | 37 | def test_compress_gzip_read(tmpdir: str | Path) -> None: 38 | file_name = tmpdir / "compressed_file" 39 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 40 | 41 | # Write is readable by gzip 42 | with open(file_name, "wb") as f: 43 | f.write(CompressedFile("GZIP").compress(payload)) 44 | 45 | with gzip.open(file_name, "rb") as f: 46 | assert f.read() == payload 47 | 48 | 49 | def test_compress_decompress_file(tmpdir: str | Path) -> None: 50 | file_name = tmpdir / "compressed_file" 51 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 52 | 53 | for algorithm in CompressedFile.supported_compressions(): 54 | with open(file_name, "wb") as f: 55 | f.write(CompressedFile(algorithm).compress(payload)) 56 | 57 | with open(file_name, "rb") as f: 58 | assert CompressedFile(algorithm).decompress(f.read()) == payload 59 | 60 | 61 | def test_compress_decompress_in_memory() -> None: 62 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 63 | 64 | for algorithm in CompressedFile.supported_compressions(): 65 | # Decompress of compress is the same. 66 | assert CompressedFile(algorithm).decompress( 67 | CompressedFile(algorithm).compress(payload)) == payload 68 | 69 | 70 | def test_compresses() -> None: 71 | # This should be compressible 72 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 73 | 74 | assert len(CompressedFile("GZIP").compress(payload)) < len(payload) 75 | 76 | 77 | @pytest.mark.parametrize("compression", CompressedFile.supported_compressions()) 78 | def test_compression_works(compression: CompressionT, 79 | tmpdir: Union[str, Path]) -> None: 80 | file_name = tmpdir / "compressed_file" 81 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB 82 | 83 | # Write is readable by gzip 84 | with open(file_name, "wb") as f: 85 | compressor = CompressedFile(compression) 86 | f.write(compressor.compress(payload)) 87 | 88 | with open(file_name, "rb") as f: 89 | assert CompressedFile(compression).decompress(f.read()) == payload 90 | -------------------------------------------------------------------------------- /tests/io/test_rust_iter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | from typing import get_args, Union 17 | 18 | import numpy as np 19 | import numpy.typing as npt 20 | import pytest 21 | 22 | import sedpack 23 | from sedpack.io import Dataset 24 | from sedpack.io import Metadata 25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT 26 | 27 | from sedpack import _sedpack_rs # type: ignore[attr-defined] 28 | 29 | 30 | def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike, 31 | shard_file_type: ShardFileTypeT, compression: CompressionT) -> None: 32 | array_of_values = np.random.random((1024, 138)) 33 | array_of_values = array_of_values.astype(dtype) 34 | 35 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment" 36 | 37 | # Create a dataset 38 | 39 | dataset_metadata = Metadata(description="Test of the lib") 40 | 41 | example_attributes = [ 42 | sedpack.io.metadata.Attribute( 43 | name="attribute_name", 44 | dtype=str(dtype), 45 | shape=array_of_values[0].shape, 46 | ), 47 | ] 48 | 49 | dataset_structure = sedpack.io.metadata.DatasetStructure( 50 | saved_data_description=example_attributes, 51 | compression=compression, 52 | examples_per_shard=256, 53 | shard_file_type=shard_file_type, 54 | ) 55 | 56 | dataset = Dataset.create( 57 | path=tiny_experiment_path, 58 | metadata=dataset_metadata, 59 | dataset_structure=dataset_structure, 60 | ) 61 | 62 | # Fill data in the dataset 63 | 64 | with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: 65 | for attribute_value in array_of_values: 66 | filler.write_example( 67 | values={"attribute_name": attribute_value}, 68 | split=TRAIN_SPLIT, 69 | ) 70 | 71 | # Check the data is correct 72 | # Reopen the dataset 73 | dataset = Dataset(tiny_experiment_path) 74 | dataset.check() 75 | 76 | for i, example in enumerate( 77 | dataset.as_numpy_iterator_rust(split=TRAIN_SPLIT, 78 | repeat=False, 79 | shuffle=0)): 80 | assert np.allclose(example["attribute_name"], array_of_values[i]) 81 | 82 | # We tested everything 83 | assert i + 1 == array_of_values.shape[ 84 | 0], "Not all examples have been iterated" 85 | 86 | 87 | @pytest.mark.parametrize("compression", 88 | list(_sedpack_rs.RustIter.supported_compressions())) 89 | def test_end2end_as_numpy_iterator_fb(compression: str, 90 | tmpdir: Union[str, Path]) -> None: 91 | end2end( 92 | tmpdir=tmpdir, 93 | dtype="float32", 94 | shard_file_type="fb", 95 | compression=compression, 96 | ) 97 | 98 | 99 | def test_correct_compressions(): 100 | for compression in _sedpack_rs.RustIter.supported_compressions(): 101 | assert compression in set(get_args(CompressionT)) 102 | -------------------------------------------------------------------------------- /src/sedpack/io/shard/shard_writer_tfrec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset shard manipulation. 
15 | 
16 | For information on how to read and write TFRecord files, see
17 | https://www.tensorflow.org/tutorials/load_data/tfrecord
18 | """
19 | 
20 | from pathlib import Path
21 | from typing import Any
22 | 
23 | from sedpack.io.metadata import DatasetStructure
24 | from sedpack.io.tfrec.tfdata import to_tfrecord
25 | from sedpack.io.types import ExampleT, CompressionT
26 | from sedpack.io.shard.shard_writer_base import ShardWriterBase
27 | 
28 | 
29 | class ShardWriterTFRec(ShardWriterBase):
30 |     """Shard writing capabilities.
31 |     """
32 | 
33 |     def __init__(self, dataset_structure: DatasetStructure,
34 |                  shard_file: Path) -> None:
35 |         """Collect information about a new shard.
36 | 
37 |         Args:
38 | 
39 |           dataset_structure (DatasetStructure): The structure of data being
40 |           saved.
41 | 
42 |           shard_file (Path): Full path to the shard file.
43 |         """
44 |         assert dataset_structure.shard_file_type == "tfrec"
45 | 
46 |         super().__init__(
47 |             dataset_structure=dataset_structure,
48 |             shard_file=shard_file,
49 |         )
50 | 
51 |         # The tf.io.TFRecordWriter is opened lazily on the first `write` call
52 |         # and reset to None as soon as `close` is called.
53 |         self._tf_shard_writer: Any | None = None
54 | 
55 |     def _write(self, values: ExampleT) -> None:
56 |         """Write an example to disk. Writing may be buffered.
57 | 
58 |         Args:
59 | 
60 |           values (ExampleT): Attribute values.
61 |         """
62 |         # TensorFlow is an optional dependency.
63 |         import tensorflow as tf  # pylint: disable=import-outside-toplevel
64 | 
65 |         if (self.dataset_structure.compression
66 |                 not in ShardWriterTFRec.supported_compressions()):
67 |             raise ValueError(
68 |                 f"Unsupported compression {self.dataset_structure.compression}"
69 |                 " requested for TFRecordWriter, expected "
70 |                 f"{ShardWriterTFRec.supported_compressions()}")
71 |         if not self._tf_shard_writer:
72 |             self._tf_shard_writer = tf.io.TFRecordWriter(
73 |                 str(self._shard_file),
74 |                 self.dataset_structure.compression,  # type: ignore[arg-type]
75 |             )
76 | 
77 |         example = to_tfrecord(
78 |             saved_data_description=self.dataset_structure.
79 |             saved_data_description,
80 |             values=values,
81 |         )
82 |         self._tf_shard_writer.write(example)
83 | 
84 |     def close(self) -> tuple[str, ...]:
85 |         """Close the shard file(s).
86 |         """
87 |         if not self._tf_shard_writer:
88 |             raise ValueError("Trying to close a shard that was not open")
89 |         self._tf_shard_writer.close()
90 |         self._tf_shard_writer = None
91 |         return self._compute_file_hash_checksums()
92 | 
93 |     @staticmethod
94 |     def supported_compressions() -> list[CompressionT]:
95 |         """Return a list of supported compression types.
96 |         """
97 |         return ["GZIP", "ZLIB", ""]
98 | 
--------------------------------------------------------------------------------
/tests/io/test_end2end_async.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Google LLC
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
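# End-to-end test of the asynchronous reader: write a small dataset, reopen
# it, and stream every example back through `as_numpy_iterator_async`.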
14 | 
15 | from pathlib import Path
16 | from typing import Union
17 | 
18 | import numpy as np
19 | import numpy.typing as npt
20 | import pytest
21 | 
22 | import sedpack
23 | from sedpack.io import Dataset
24 | from sedpack.io import Metadata
25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT
26 | 
27 | pytest_plugins = ("pytest_asyncio",)
28 | 
29 | 
30 | async def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike,
31 |                   shard_file_type: ShardFileTypeT,
32 |                   compression: CompressionT) -> None:
33 |     array_of_values = np.random.random((1024, 138))
34 |     array_of_values = array_of_values.astype(dtype)
35 | 
36 |     tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
37 | 
38 |     # Create a dataset
39 | 
40 |     dataset_metadata = Metadata(description="Test of the lib")
41 | 
42 |     example_attributes = [
43 |         sedpack.io.metadata.Attribute(
44 |             name="attribute_name",
45 |             dtype=str(dtype),
46 |             shape=array_of_values[0].shape,
47 |         ),
48 |     ]
49 | 
50 |     dataset_structure = sedpack.io.metadata.DatasetStructure(
51 |         saved_data_description=example_attributes,
52 |         compression=compression,
53 |         examples_per_shard=256,
54 |         shard_file_type=shard_file_type,
55 |     )
56 | 
57 |     dataset = Dataset.create(
58 |         path=tiny_experiment_path,
59 |         metadata=dataset_metadata,
60 |         dataset_structure=dataset_structure,
61 |     )
62 | 
63 |     # Fill data in the dataset
64 | 
65 |     with dataset.filler() as filler:
66 |         for attribute_value in array_of_values:
67 |             filler.write_example(
68 |                 values={"attribute_name": attribute_value},
69 |                 split=TRAIN_SPLIT,
70 |             )
71 | 
72 |     # Check the data is correct
73 |     # Reopen the dataset
74 |     dataset = Dataset(tiny_experiment_path)
75 |     dataset.check()
76 | 
77 |     i = 0
78 |     async for example in dataset.as_numpy_iterator_async(split=TRAIN_SPLIT,
79 |                                                          shuffle=0,
80 |                                                          repeat=False):
81 |         assert np.allclose(example["attribute_name"], array_of_values[i])
82 |         i += 1
83 | 
84 |     # We tested everything
85 |     assert i == array_of_values.shape[0], "Not all examples have been iterated"
86 | 
87 | 
88 | @pytest.mark.asyncio
89 | async def test_end2end_as_numpy_iterator_npz(tmpdir: Union[str, Path]) -> None:
90 |     await end2end(tmpdir=tmpdir,
91 |                   dtype="float32",
92 |                   shard_file_type="npz",
93 |                   compression="ZIP")
94 | 
95 | 
96 | @pytest.mark.asyncio
97 | async def test_end2end_as_numpy_iterator_fb(tmpdir: Union[str, Path]) -> None:
98 |     await end2end(tmpdir=tmpdir,
99 |                   dtype="float32",
100 |                   shard_file_type="fb",
101 |                   compression="LZ4")
--------------------------------------------------------------------------------
/tests/io/test_write_multiprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
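"""Test of `Dataset.write_multiprocessing`: several `DatasetFiller`s write
disjoint parts of the dataset (run with single_process=True here) which is
then read back sequentially."""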
14 | 15 | from pathlib import Path 16 | from typing import Union 17 | 18 | import numpy as np 19 | import numpy.typing as npt 20 | 21 | import sedpack 22 | from sedpack.io import Dataset, DatasetFiller 23 | from sedpack.io import Metadata 24 | from sedpack.io.types import SplitT, TRAIN_SPLIT 25 | 26 | 27 | def feed_writer(dataset_filler: DatasetFiller, 28 | array_of_values: npt.NDArray[np.generic], split: SplitT) -> int: 29 | # Fill data in the dataset 30 | 31 | with dataset_filler as filler: 32 | for attribute_value in array_of_values: 33 | filler.write_example( 34 | values={"attribute_name": attribute_value}, 35 | split=split, 36 | ) 37 | 38 | return len(array_of_values) 39 | 40 | 41 | def test_write_multiprocessing(tmpdir: Union[str, Path]) -> None: 42 | dtype = "float32" 43 | array_of_values = np.random.random((1024, 138)) 44 | array_of_values = array_of_values.astype(dtype) 45 | 46 | tiny_experiment_path: Path = Path(tmpdir) / "write_multiprocessing" 47 | 48 | # Create a dataset 49 | 50 | dataset_metadata = Metadata( 51 | description="Test of the lib write_multiprocessing") 52 | 53 | example_attributes = [ 54 | sedpack.io.metadata.Attribute( 55 | name="attribute_name", 56 | dtype=str(dtype), 57 | shape=array_of_values[0].shape, 58 | ), 59 | ] 60 | 61 | dataset_structure = sedpack.io.metadata.DatasetStructure( 62 | saved_data_description=example_attributes, 63 | compression="GZIP", 64 | examples_per_shard=256, 65 | shard_file_type="fb", 66 | ) 67 | 68 | dataset = Dataset.create( 69 | path=tiny_experiment_path, 70 | metadata=dataset_metadata, 71 | dataset_structure=dataset_structure, 72 | ) 73 | 74 | custom_arguments = [ 75 | (array_of_values[:100],), 76 | (array_of_values[100:900],), 77 | (array_of_values[900:],), 78 | ] 79 | custom_kwarguments = [ 80 | { 81 | "split": "train" 82 | }, 83 | { 84 | "split": "train" 85 | }, 86 | { 87 | "split": "train" 88 | }, 89 | ] 90 | 91 | results = dataset.write_multiprocessing( 92 | feed_writer=feed_writer, 93 | custom_arguments=custom_arguments, 94 | custom_kwarguments=custom_kwarguments, 95 | single_process=True) 96 | 97 | assert results == [ 98 | len(part_of_array_of_values[0]) 99 | for part_of_array_of_values in custom_arguments 100 | ] 101 | 102 | # Check the data is correct 103 | 104 | for i, example in enumerate( 105 | dataset.as_numpy_iterator_rust_batched( 106 | split=TRAIN_SPLIT, 107 | shuffle=0, 108 | repeat=False, 109 | batch_size=1, 110 | )): 111 | assert np.allclose(example["attribute_name"], array_of_values[i:i + 1]) 112 | 113 | # We tested everything 114 | assert i + 1 == array_of_values.shape[ 115 | 0], "Not all examples have been iterated" 116 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/shardfile/Example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # automatically generated by the FlatBuffers compiler, do not modify 16 | 17 | # namespace: shardfile 18 | 19 | # pylint: skip-file 20 | 21 | import flatbuffers 22 | from flatbuffers.builder import Builder 23 | from flatbuffers.compat import import_numpy 24 | 25 | from sedpack.io.flatbuffer.shardfile.Attribute import Attribute 26 | 27 | np = import_numpy() 28 | 29 | 30 | class Example(object): 31 | __slots__ = ['_tab'] 32 | 33 | @classmethod 34 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Example": 35 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 36 | x = Example() 37 | x.Init(buf, n + offset) 38 | return x 39 | 40 | # Example 41 | def Init(self, buf: bytes, pos: int) -> None: 42 | self._tab = flatbuffers.table.Table(buf, pos) 43 | 44 | # Example 45 | def Attributes(self, j: int) -> Attribute | None: 46 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 47 | if o != 0: 48 | x = self._tab.Vector(o) 49 | x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 50 | x = self._tab.Indirect(x) 51 | obj = Attribute() 52 | obj.Init(self._tab.Bytes, x) 53 | return obj 54 | return None 55 | 56 | # Example 57 | def AttributesLength(self) -> int: 58 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type( 59 | self._tab.Offset(4)) 60 | if o != 0: 61 | return self._tab.VectorLen(o) # type: ignore[no-any-return] 62 | return 0 63 | 64 | # Example 65 | def AttributesIsNone(self) -> bool: 66 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type( 67 | self._tab.Offset(4)) 68 | return o == 0 69 | 70 | 71 | def ExampleStart(builder: Builder) -> None: # type: ignore[no-any-unimported] 72 | builder.StartObject(1) 73 | 74 | 75 | def Start(builder: Builder) -> None: # type: ignore[no-any-unimported] 76 | ExampleStart(builder) 77 | 78 | 79 | def ExampleAddAttributes( # type: ignore[no-any-unimported] 80 | builder: Builder, attributes: int) -> None: 81 | builder.PrependUOffsetTRelativeSlot( 82 | 0, flatbuffers.number_types.UOffsetTFlags.py_type(attributes), 0) 83 | 84 | 85 | def AddAttributes( # type: ignore[no-any-unimported] 86 | builder: Builder, attributes: int) -> None: 87 | ExampleAddAttributes(builder, attributes) 88 | 89 | 90 | def ExampleStartAttributesVector( # type: ignore[no-any-unimported] 91 | builder: Builder, numElems: int) -> int: 92 | return builder.StartVector(4, numElems, 4) # type: ignore[no-any-return] 93 | 94 | 95 | def StartAttributesVector( # type: ignore[no-any-unimported] 96 | builder: Builder, numElems: int) -> int: 97 | return ExampleStartAttributesVector(builder, numElems) 98 | 99 | 100 | def ExampleEnd(builder: Builder) -> int: # type: ignore[no-any-unimported] 101 | return builder.EndObject() # type: ignore[no-any-return] 102 | 103 | 104 | def End(builder: Builder) -> int: # type: ignore[no-any-unimported] 105 | return ExampleEnd(builder) 106 | -------------------------------------------------------------------------------- /src/sedpack/io/shard/shard_writer_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base class for shard writing depending on shard_file_type. 15 | """ 16 | 17 | from abc import ABC, abstractmethod 18 | from pathlib import Path 19 | 20 | import numpy as np 21 | 22 | import sedpack 23 | from sedpack.io.metadata import DatasetStructure 24 | from sedpack.io.types import ExampleT, CompressionT 25 | 26 | 27 | class ShardWriterBase(ABC): 28 | """Shard writing capabilities. 29 | """ 30 | 31 | def __init__(self, dataset_structure: DatasetStructure, 32 | shard_file: Path) -> None: 33 | """Collect information about a new shard. 34 | 35 | Args: 36 | 37 | dataset_structure (DatasetStructure): The structure of data being 38 | saved. 39 | 40 | shard_file (Path): Full path to the shard file. 41 | """ 42 | # Information needed to save the shard. 43 | self.dataset_structure: DatasetStructure = dataset_structure 44 | self._shard_file: Path = shard_file 45 | 46 | # Make sure that the directory exists. 47 | self._shard_file.parent.mkdir(exist_ok=True, parents=True) 48 | 49 | # Make sure that the compression is supported for this shard file type. 50 | assert dataset_structure.compression in self.supported_compressions() 51 | 52 | def write(self, values: ExampleT) -> None: 53 | """Write an example on disk. Writing may be buffered. 54 | 55 | Args: 56 | 57 | values (ExampleT): Attribute values. 58 | """ 59 | # Check the values are correct type and shape. 60 | for attribute in self.dataset_structure.saved_data_description: 61 | # If the attribute dtype is "bytes" and shape is empty tuple we 62 | # consider this a variable size attribute and do not check shape. 63 | if attribute.has_variable_size(): 64 | continue 65 | 66 | # Else check the shape (the value should be a NumPy array but maybe 67 | # it is an int or bytearray). 68 | current_shape = np.array(values[attribute.name]).shape 69 | if current_shape != attribute.shape: 70 | raise ValueError(f"Attribute {attribute.name} has shape " 71 | f"{current_shape} expected {attribute.shape}") 72 | 73 | self._write(values=values) 74 | 75 | @abstractmethod 76 | def _write(self, values: ExampleT) -> None: 77 | """Write an example on disk. Writing may be buffered. 78 | 79 | Args: 80 | 81 | values (ExampleT): Attribute values. 82 | """ 83 | 84 | @abstractmethod 85 | def close(self) -> tuple[str, ...]: 86 | """Close the shard file(-s). 87 | """ 88 | 89 | @staticmethod 90 | @abstractmethod 91 | def supported_compressions() -> list[CompressionT]: 92 | """Return a list of supported compression types. 93 | """ 94 | 95 | def _compute_file_hash_checksums(self) -> tuple[str, ...]: 96 | """Compute hash checksums of the shard file(-s). 
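
        Returns one checksum string per algorithm listed in
        `self.dataset_structure.hash_checksum_algorithms`.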
""" 97 | return sedpack.io.utils.hash_checksums( 98 | file_path=self._shard_file, 99 | hashes=self.dataset_structure.hash_checksum_algorithms, 100 | ) 101 | -------------------------------------------------------------------------------- /tests/io/shard_info_iterator/test_balanced_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | 17 | from sedpack.io.shard_file_metadata import ShardInfo 18 | from sedpack.io.shard_info_iterator.balanced_iterator import _split_balancing 19 | 20 | 21 | def test_split_balancing_all() -> None: 22 | shard_list: list[ShardInfo] = [ 23 | ShardInfo( 24 | file_infos=(), 25 | number_of_examples=1, 26 | custom_metadata={ 27 | "id": i, 28 | "id_mod_3_is_zero": i % 3 == 0, 29 | }, 30 | ) for i in range(100) 31 | ] 32 | 33 | balanced = _split_balancing( 34 | shard_list=shard_list, 35 | balance_by=[ 36 | lambda shard_info: shard_info.custom_metadata["id_mod_3_is_zero"] 37 | ], 38 | repeat=False, 39 | shuffle=10, 40 | ) 41 | 42 | assert set( 43 | shard_info.custom_metadata["id"] for shard_info in balanced) == set( 44 | shard_info.custom_metadata["id"] for shard_info in shard_list) 45 | 46 | 47 | def test_split_balancing_balances() -> None: 48 | shard_list: list[ShardInfo] = [ 49 | ShardInfo( 50 | file_infos=(), 51 | number_of_examples=1, 52 | custom_metadata={ 53 | "id": i, 54 | "id_mod_3_is_zero": i % 3 == 0, 55 | }, 56 | ) for i in range(100) 57 | ] 58 | 59 | balanced = _split_balancing( 60 | shard_list=shard_list, 61 | balance_by=[ 62 | lambda shard_info: shard_info.custom_metadata["id_mod_3_is_zero"] 63 | ], 64 | repeat=True, 65 | shuffle=10, 66 | ) 67 | 68 | take_n = 1_000 69 | assert sum(shard_info.custom_metadata["id_mod_3_is_zero"] for shard_info in 70 | list(itertools.islice(balanced, take_n))) == take_n // 2 71 | 72 | 73 | def test_custom_weight() -> None: 74 | shard_list: list[ShardInfo] = [ 75 | ShardInfo( 76 | file_infos=(), 77 | number_of_examples=1, 78 | custom_metadata={ 79 | "id": i, 80 | "id_mod_3_is_zero": i % 3 == 0, 81 | }, 82 | ) for i in range(100) 83 | ] 84 | 85 | class BalanceBy: 86 | 87 | def __call__(self, shard_info: ShardInfo) -> bool: 88 | return shard_info.custom_metadata["id_mod_3_is_zero"] 89 | 90 | def weight(self, shard_info: ShardInfo) -> float: 91 | if shard_info.custom_metadata["id_mod_3_is_zero"]: 92 | # Do four times more of the zeros. Meaning for each non-zero 93 | # example there are four zero examples -> 80% of the zero 94 | # examples. 
95 | return 4 96 | else: 97 | return 1 98 | 99 | balance_by_top = BalanceBy() 100 | 101 | balanced = _split_balancing( 102 | shard_list=shard_list, 103 | balance_by=[ 104 | balance_by_top, 105 | ], 106 | repeat=True, 107 | shuffle=10, 108 | ) 109 | 110 | take_n = 1_000 111 | assert sum(shard_info.custom_metadata["id_mod_3_is_zero"] for shard_info in 112 | list(itertools.islice(balanced, take_n))) == take_n * (4 / 5) 113 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/shardfile/Shard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # automatically generated by the FlatBuffers compiler, do not modify 16 | # Added best-effort type hints since this file is not likely to change. 17 | 18 | # namespace: shardfile 19 | 20 | # pylint: skip-file 21 | 22 | import flatbuffers 23 | from flatbuffers import Builder 24 | from flatbuffers.compat import import_numpy 25 | 26 | from sedpack.io.flatbuffer.shardfile.Example import Example 27 | 28 | np = import_numpy() 29 | 30 | 31 | class Shard(object): 32 | __slots__ = ['_tab'] 33 | 34 | @classmethod 35 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Shard": 36 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 37 | x = Shard() 38 | x.Init(buf, n + offset) 39 | return x 40 | 41 | # Shard 42 | def Init(self, buf: bytes, pos: int) -> None: 43 | self._tab = flatbuffers.table.Table(buf, pos) 44 | 45 | # Shard 46 | def Examples(self, j: int) -> Example | None: 47 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 48 | if o != 0: 49 | x = self._tab.Vector(o) 50 | x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 51 | x = self._tab.Indirect(x) 52 | obj = Example() 53 | obj.Init(self._tab.Bytes, x) 54 | return obj 55 | return None 56 | 57 | # Shard 58 | def ExamplesLength(self) -> int: 59 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 60 | if o != 0: 61 | return self._tab.VectorLen(o) # type: ignore[no-any-return] 62 | return 0 63 | 64 | # Shard 65 | def ExamplesIsNone(self) -> bool: 66 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type( 67 | self._tab.Offset(4)) 68 | return o == 0 69 | 70 | 71 | def ShardStart(builder: Builder) -> None: # type: ignore[no-any-unimported] 72 | builder.StartObject(1) 73 | 74 | 75 | def Start(builder: Builder) -> None: # type: ignore[no-any-unimported] 76 | ShardStart(builder) 77 | 78 | 79 | def ShardAddExamples( # type: ignore[no-any-unimported] 80 | builder: Builder, 81 | examples: int, 82 | ) -> None: 83 | builder.PrependUOffsetTRelativeSlot( 84 | 0, 85 | flatbuffers.number_types.UOffsetTFlags.py_type(examples), 86 | 0, 87 | ) 88 | 89 | 90 | def AddExamples( # type: ignore[no-any-unimported] 91 | builder: Builder, 92 | examples: int, 93 | ) -> None: 94 | ShardAddExamples(builder, examples) 95 | 96 | 97 | def 
ShardStartExamplesVector( # type: ignore[no-any-unimported] 98 | builder: Builder, 99 | numElems: int, 100 | ) -> int: 101 | return builder.StartVector(4, numElems, 4) # type: ignore[no-any-return] 102 | 103 | 104 | def StartExamplesVector( # type: ignore[no-any-unimported] 105 | builder: Builder, numElems: int) -> int: 106 | return ShardStartExamplesVector( 107 | builder, 108 | numElems, 109 | ) 110 | 111 | 112 | def ShardEnd(builder: Builder) -> int: # type: ignore[no-any-unimported] 113 | return builder.EndObject() # type: ignore[no-any-return] 114 | 115 | 116 | def End(builder: Builder) -> int: # type: ignore[no-any-unimported] 117 | return ShardEnd(builder) 118 | -------------------------------------------------------------------------------- /src/sedpack/io/tfrec/read.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Read a dataset using pure asyncio. 15 | """ 16 | 17 | import asyncio 18 | import os 19 | from pathlib import Path 20 | from typing import Any, AsyncIterator, Callable, Iterable 21 | 22 | import tensorflow as tf 23 | 24 | from sedpack.io.metadata import DatasetStructure 25 | from sedpack.io.shard import IterateShardBase 26 | from sedpack.io.shard.iterate_shard_base import T 27 | from sedpack.io.tfrec.tfdata import get_from_tfrecord 28 | from sedpack.io.types import ExampleT 29 | from sedpack.io.utils import func_or_identity 30 | 31 | 32 | class IterateShardTFRec(IterateShardBase[T]): 33 | """Iterate a TFRec shard. 34 | """ 35 | 36 | def __init__( 37 | self, 38 | dataset_structure: DatasetStructure, 39 | process_record: Callable[[ExampleT], Any] | None, 40 | num_parallel_calls: int = os.cpu_count() or 4, 41 | ) -> None: 42 | super().__init__(dataset_structure=dataset_structure, 43 | process_record=process_record) 44 | # This is not pickleable, but can be created on the fly. 45 | self.from_tfrecord: Callable[[Any], Any] | None = None 46 | self.num_parallel_calls: int = num_parallel_calls 47 | 48 | def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: 49 | """Iterate a shard saved in the TFRec format 50 | """ 51 | if not self.from_tfrecord: 52 | self.from_tfrecord = get_from_tfrecord( 53 | self.dataset_structure.saved_data_description) 54 | 55 | # Read the shard. 56 | tf_dataset_records = tf.data.TFRecordDataset( 57 | str(file_path), 58 | compression_type=self.dataset_structure. 59 | compression, # type: ignore[arg-type] 60 | ) 61 | 62 | # Decode examples. 
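        # The map is lazy and parses up to `num_parallel_calls` records in
        # parallel; `from_tfrecord` turns each serialized record back into a
        # dict of tensors.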
63 |         tf_dataset_examples = tf_dataset_records.map(
64 |             self.from_tfrecord,
65 |             num_parallel_calls=self.num_parallel_calls,
66 |         )
67 | 
68 |         yield from tf_dataset_examples.as_numpy_iterator()  # type: ignore[misc]
69 | 
70 |     # TODO(issue #85) fix and test async iterator typing
71 |     async def iterate_shard_async(  # pylint: disable=invalid-overridden-method
72 |         self,
73 |         file_path: Path,
74 |     ) -> AsyncIterator[ExampleT]:
75 |         for example in self.iterate_shard(file_path=file_path):
76 |             yield example
77 |             # Give up the event loop (a bit dirty).
78 |             await asyncio.sleep(0)
79 | 
80 |     def process_and_list(self, shard_file: Path) -> list[T]:
81 |         """Return a list of processed examples. Used as a function call in a
82 |         different process. Returning a list, as opposed to an iterator, allows
83 |         all of the work to happen in the other process; all that remains is a
84 |         memory copy between processes.
85 | 
86 |         TODO think of a way to avoid copying memory between processes.
87 | 
88 |         Args:
89 | 
90 |         shard_file (Path): Path to the shard file.
91 | 
92 |         Returns a list of the examples present in the shard identified by the
93 |         path, with the `process_record` function applied (when not None).
94 |         """
95 |         process_record = func_or_identity(self.process_record)
96 | 
97 |         return [
98 |             process_record(example)
99 |             for example in self.iterate_shard(file_path=shard_file)
100 |         ]
--------------------------------------------------------------------------------
/src/sedpack/io/compress.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Compress and decompress data with a specified compression type.
15 | """
16 | 
17 | import bz2
18 | import gzip
19 | import lzma
20 | 
21 | import lz4.frame
22 | import zstandard as zstd
23 | 
24 | from sedpack.io.types import CompressionT
25 | 
26 | 
27 | class CompressedFile:
28 |     """Provide easy compress and decompress functions for compressed data.
29 |     """
30 | 
31 |     def __init__(self, compression_type: CompressionT) -> None:
32 |         """Initialize compression handling for the given compression type.
33 | 
34 |         compression_type (CompressionT): The type of compression. Note that ZIP
35 |         is not supported yet.
36 |         """
37 |         self.compression_type: CompressionT = compression_type
38 | 
39 |         if compression_type in ["ZIP"]:
40 |             # Zip is a container meaning we open something.zip and inside that
41 |             # we open file(-s). This requires more work on the context manager
42 |             # side. Not implementing yet.
43 |             raise NotImplementedError(f"Compression {compression_type} is not "
44 |                                       f"supported yet by CompressedFile")
45 | 
46 |     @staticmethod
47 |     def supported_compressions() -> list[CompressionT]:
48 |         """Return a list of supported compression types.
49 |         """
50 |         return [
51 |             "",
52 |             "BZ2",
53 |             "GZIP",
54 |             "LZMA",
55 |             "LZ4",
56 |             "ZLIB",
57 |             "ZSTD",
58 |         ]
59 | 
60 |     def compress(self, data: bytes) -> bytes:
61 |         """Compression.
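        Note that GZIP and ZLIB are currently both implemented via
        `gzip.compress` at compression level 9.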
62 | 
63 |         Args:
64 | 
65 |         data (bytes): Content to compress.
66 | 
67 |         Returns: the compressed data.
68 |         """
69 |         match self.compression_type:
70 |             case "":
71 |                 return data
72 |             case "GZIP" | "ZLIB":
73 |                 return gzip.compress(data, compresslevel=9)
74 |             case "BZ2":
75 |                 return bz2.compress(data, compresslevel=9)
76 |             case "LZMA":
77 |                 return lzma.compress(data)
78 |             case "LZ4":
79 |                 return lz4.frame.compress(data)  # type: ignore[no-any-return]
80 |             case "ZSTD":
81 |                 return zstd.compress(data)
82 |             case _:
83 |                 raise NotImplementedError(f"CompressedFile does not implement "
84 |                                           f"{self.compression_type} yet.")
85 | 
86 |     def decompress(self, data: bytes) -> bytes:
87 |         """Decompression.
88 | 
89 |         Args:
90 | 
91 |         data (bytes): Content of the file to be decompressed.
92 | 
93 |         Returns: the decompressed data.
94 |         """
95 |         match self.compression_type:
96 |             case "":
97 |                 return data
98 |             case "GZIP" | "ZLIB":
99 |                 return gzip.decompress(data)
100 |             case "BZ2":
101 |                 return bz2.decompress(data)
102 |             case "LZMA":
103 |                 return lzma.decompress(data)
104 |             case "LZ4":
105 |                 return lz4.frame.decompress(data)  # type: ignore[no-any-return]
106 |             case "ZSTD":
107 |                 return zstd.decompress(data)
108 |             case _:
109 |                 raise NotImplementedError(f"CompressedFile does not implement "
110 |                                           f"{self.compression_type} yet.")
--------------------------------------------------------------------------------
/src/sedpack/io/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset class for building and loading sedpack datasets."""
15 | 
16 | from pathlib import Path
17 | from typing import Union
18 | 
19 | from sedpack.io.dataset_base import DatasetBase
20 | from sedpack.io.dataset_iteration import DatasetIteration
21 | from sedpack.io.dataset_iteration_tf import DatasetIterationTF
22 | from sedpack.io.dataset_writing import DatasetWriting
23 | from sedpack.io.errors import DatasetExistsError
24 | from sedpack.io.metadata import DatasetInfo, DatasetStructure, Metadata
25 | 
26 | 
27 | class Dataset(
28 |         DatasetIteration,
29 |         DatasetIterationTF,
30 |         DatasetWriting,
31 | ):
32 |     """Dataset class."""
33 | 
34 |     def __init__(self,
35 |                  path: Union[str, Path],
36 |                  create_dataset: bool = False) -> None:
37 |         """Class for saving and loading a dataset.
38 | 
39 |         Args:
40 | 
41 |             path (Union[str, Path]): Path from where to load the dataset (a
42 |             directory -- for instance
43 |             "/home/user_name/datasets/my_awesome_dataset").
44 | 
45 |             create_dataset (bool): Are we creating a new dataset? Defaults to
46 |             False, which is used when (down-)loading a dataset.
47 | 
48 |         Raises:
49 | 
50 |             ValueError if the dataset was created using a newer version of the
51 |             sedpack than the one trying to load it. See
52 |             sedpack.__version__ docstring.
53 | 
54 |             FileNotFoundError if `create_dataset` is False and the
55 |             `dataset_info.json` file does not exist.
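
        Example use:

            dataset = Dataset("/home/user_name/datasets/my_awesome_dataset")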
56 | """ 57 | dataset_info: DatasetInfo 58 | if create_dataset: 59 | # Default DatasetInfo. 60 | dataset_info = DatasetInfo() 61 | else: 62 | # Load the information. 63 | dataset_info = DatasetBase._load(Path(path)) 64 | 65 | super().__init__(path=path, dataset_info=dataset_info) 66 | 67 | @staticmethod 68 | def create( 69 | path: Union[str, Path], 70 | metadata: Metadata, 71 | dataset_structure: DatasetStructure, 72 | ) -> "Dataset": 73 | """Create an empty dataset to be filled using the `filler` or 74 | `write_multiprocessing` API. 75 | 76 | Args: 77 | 78 | path (Union[str, Path]): Path where the dataset gets saved (a 79 | directory -- for instance 80 | "/home/user_name/datasets/my_awesome_dataset"). 81 | 82 | metadata (Metadata): Information about this dataset. 83 | 84 | dataset_structure (DatasetStructure): Structure of saved records. 85 | 86 | Raises: DatasetExistsError if creating this object would overwrite the 87 | corresponding config file. 88 | """ 89 | # Create a new object. 90 | dataset = Dataset(path=Path(path), create_dataset=True) 91 | 92 | # Do not overwrite an existing dataset. 93 | if Dataset._get_config_path(dataset.path).is_file(): 94 | # Raise if the dataset already exists. 95 | raise DatasetExistsError(dataset.path) 96 | 97 | # Create a new dataset directory if needed. 98 | dataset.path.mkdir(parents=True, exist_ok=True) 99 | 100 | # Fill metadata and structure parameters. 101 | dataset.metadata = metadata 102 | dataset.dataset_structure = dataset_structure 103 | 104 | # Write empty config. 105 | dataset.write_config(updated_infos=[]) 106 | return dataset 107 | -------------------------------------------------------------------------------- /src/sedpack/io/shard_info_iterator/shard_info_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base class for a shard info iterator.""" 15 | import itertools 16 | from pathlib import Path 17 | 18 | from typing import Iterator 19 | 20 | from sedpack.io.metadata import DatasetInfo 21 | from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo 22 | from sedpack.io.types import SplitT 23 | 24 | 25 | class ShardInfoIterator: 26 | """Iterate shards of a dataset. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | *, 32 | dataset_path: Path, 33 | dataset_info: DatasetInfo, 34 | split: SplitT | None, 35 | repeat: bool = False, 36 | ) -> None: 37 | """Initialize shard information iteration. 38 | 39 | Args: 40 | 41 | dataset_path (Path): The path to the dataset directory. 42 | 43 | dataset_info (DatasetInfo): The information about the iterated 44 | dataset. 45 | 46 | split (SplitT | None): Which split to iterate or all if set to None. 47 | 48 | repeat (bool): Should we cycle indefinitely? 
49 | """ 50 | self.dataset_path: Path = Path(dataset_path) 51 | self.dataset_info: DatasetInfo = dataset_info 52 | self.split: SplitT | None = split 53 | self.repeat: bool = repeat 54 | 55 | self._iterator: Iterator[ShardInfo] = iter([]) 56 | 57 | def __len__(self) -> int: 58 | """Either return the number of ShardInfo objects iterated or raise a 59 | ValueError if infinite cycle. 60 | """ 61 | if self.number_of_shards() == 0 or not self.repeat: 62 | return self.number_of_shards() 63 | raise ValueError("Infinite iteration") 64 | 65 | def number_of_shards(self) -> int: 66 | """Return the number of distinct shards that are iterated. When 67 | repeated this method still returns a finite answer. 68 | """ 69 | # Single split. 70 | if self.split is None: 71 | # Sum all splits. 72 | return sum(shard_list_info.number_of_shards 73 | for shard_list_info in self.dataset_info.splits.values()) 74 | 75 | if self.split not in self.dataset_info.splits: 76 | return 0 77 | 78 | return self.dataset_info.splits[self.split].number_of_shards 79 | 80 | def _shard_info_iterator( 81 | self, shard_list_info: ShardListInfo) -> Iterator[ShardInfo]: 82 | """Recursively yield `ShardInfo` from the whole directory tree. 83 | """ 84 | shard_list: ShardsList = ShardsList.model_validate_json( 85 | (self.dataset_path / 86 | shard_list_info.shard_list_info_file.file_path).read_text()) 87 | 88 | yield from shard_list.shard_files 89 | 90 | for child in shard_list.children_shard_lists: 91 | yield from self._shard_info_iterator(child) 92 | 93 | def __iter__(self) -> Iterator[ShardInfo]: 94 | """Return the shard information iterator (reentrant). 95 | """ 96 | if self.split is None: 97 | self._iterator = itertools.chain.from_iterable( 98 | self._shard_info_iterator(shard_list_info) 99 | for shard_list_info in self.dataset_info.splits.values()) 100 | else: 101 | self._iterator = self._shard_info_iterator( 102 | self.dataset_info.splits[self.split]) 103 | 104 | return self._iterator 105 | 106 | def __next__(self) -> ShardInfo: 107 | """Get the next item. 108 | """ 109 | return next(self._iterator) 110 | -------------------------------------------------------------------------------- /website/src/content/docs/tutorials/sca/overview.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: Side Channel Attacks Overview 3 | description: Side Channel Attacks Overview 4 | --- 5 | 6 | Side channel attacks (SCA) and side channel analysis (conveniently also SCA) 7 | study how to correlate data dependent computation characteristics (e.g., 8 | timing, power consumption or electromagnetic emissions) to secret values. There 9 | is a rich body of research in this area. For some evaluations large amounts of 10 | data are needed. 11 | 12 | At the [CHES 24 OPTIMIST workshop](https://optimist-ose.org/workshop-24) an 13 | initiative for [Open Tools, Interfaces and Metrics for Implementation Security 14 | Testing (OPTIMIST)](https://optimist-ose.org/) started. One of the outcomes 15 | being a call for format for trace storage. In this series of tutorials we argue 16 | that Sedpack is a viable solution for this purpose. If you have any questions, 17 | feature suggestions or patches see [sedpack GitHub 18 | repository](https://github.com/google/sedpack). 19 | 20 | The **sedpack** project started as a refactor and evolution of the data storage 21 | system used by **SCAAML (Side Channel Attacks Assisted with Machine 22 | Learning)**. 
See the [SCAAML website](https://google.github.io/scaaml/) or the
23 | [SCAAML GitHub repository](https://github.com/google/scaaml/).
24 | 
25 | ## What's Next
26 | 
27 | For the purposes of exposition we mainly focus on the very easy dataset of
28 | power consumption measurements of a textbook [AES
29 | (Wikipedia)](https://wikipedia.org/wiki/Advanced_Encryption_Standard)
30 | implementation. This dataset is both publicly available and easy to attack
31 | (analyze). If you know side channel attacks, feel free to skip to the next
32 | sections; otherwise you might choose to read through the following blog posts:
33 | 
34 | - [A Hacker Guide To Deep Learning Based Side Channel
35 |   Attacks](https://elie.net/talk/a-hackerguide-to-deep-learning-based-side-channel-attacks),
36 |   a video of the DefCon 27 (2019) talk introducing SCAAML.
37 | - [Hacker's guide to deep-learning side-channel attacks: the
38 |   theory](https://elie.net/blog/security/hacker-guide-to-deep-learning-side-channel-attacks-the-theory),
39 |   a gentle introduction to the theory.
40 | - [Hacker's guide to deep-learning side-channel attacks: code
41 |   walkthrough](https://elie.net/blog/security/hacker-guide-to-deep-learning-side-channel-attacks-code-walkthrough),
42 |   the very first version of the SCAAML model. Later in this series of
43 |   tutorials we use our most recent model on the same dataset.
44 | 
45 | In this series of tutorials we showcase how the sedpack storage can be
46 | leveraged to help with side channel analysis. This series is not a fully
47 | fledged, one-size-fits-all side channel framework. The point is to showcase a
48 | tool designed to do one thing well (data storage) that enables others to work
49 | together using as universal an interface as possible. We will see that
50 | different forms of iteration (randomly shuffled or plain), possibly with
51 | batching (yielding several examples at once), provide a foundation for both
52 | machine learning and classical attack needs.
53 | 
54 | ## Existing Tools
55 | 
56 | There is a variety of excellent tools for side channel analysis, most of which
57 | provide not only storage capabilities but also side channel analysis
58 | algorithms. This comes with the advantage of being ready to use. On the other
59 | hand, storing data in one of those $N$ tools and using algorithms from another
60 | creates an $N \times N$ compatibility matrix (with the diagonal hopefully
61 | being trivial). Our goal is to make one storage tool which can then be used by
62 | other tools, thus saving a significant amount of duplicated work.
63 | 
64 | ### An Incomplete List of Existing Tools
65 | 
66 | We acknowledge that the following list is incomplete. In fact, we hope it
67 | stays incomplete as the list keeps growing. In alphabetical order:
68 | 
69 | - [LASCAR](https://github.com/Ledger-Donjon/lascar)
70 | - [SCARR](https://github.com/decryptofy/scarr)
71 | - [riscure](https://www.keysight.com/ch/de/products/network-test/device-vulnerability-analysis.html)
72 | 
73 | If we forgot your tool and you want it to be listed here, please open an issue
74 | or send us a pull request.
--------------------------------------------------------------------------------
/tests/io/test_shard_custom_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from pathlib import Path
16 | from typing import Union
17 | 
18 | import numpy as np
19 | 
20 | import sedpack
21 | from sedpack.io import Dataset, Metadata
22 | from sedpack.io.shard_file_metadata import ShardInfo
23 | from sedpack.io.types import TRAIN_SPLIT
24 | 
25 | 
26 | def test_custom_shard_metadata(tmpdir: Union[str, Path]) -> None:
27 |     attribute_dtype = "float32"
28 |     shard_file_type = "fb"
29 | 
30 |     array_of_values = np.random.random((32, 138)).astype(dtype=attribute_dtype)
31 | 
32 |     experiment_path: Path = Path(tmpdir) / "custom_shard_metadata_experiment"
33 | 
34 |     # Create a dataset
35 |     dataset_metadata = Metadata(description="Test of the lib")
36 | 
37 |     example_attributes = [
38 |         sedpack.io.metadata.Attribute(
39 |             name="attribute_name",
40 |             dtype=attribute_dtype,
41 |             shape=array_of_values[0].shape,
42 |         ),
43 |     ]
44 | 
45 |     dataset_structure = sedpack.io.metadata.DatasetStructure(
46 |         saved_data_description=example_attributes,
47 |         compression="GZIP",
48 |         examples_per_shard=4,
49 |         shard_file_type=shard_file_type,
50 |     )
51 | 
52 |     dataset = Dataset.create(
53 |         path=experiment_path,
54 |         metadata=dataset_metadata,
55 |         dataset_structure=dataset_structure,
56 |     )
57 | 
58 |     # Fill data in the dataset
59 |     custom_metadata_0 = {"key": {"key2": "valueA"}}
60 |     custom_metadata_1 = {"key": {"key2": "valueB"}}
61 | 
62 |     with dataset.filler() as filler:
63 |         # No custom metadata.
64 |         filler.write_example(
65 |             values={"attribute_name": array_of_values[0]},
66 |             split=TRAIN_SPLIT,
67 |         )
68 |         # Still the same shard, retroactively setting metadata here.
69 |         filler.write_example(
70 |             values={"attribute_name": array_of_values[1]},
71 |             split=TRAIN_SPLIT,
72 |             custom_metadata=custom_metadata_0,
73 |         )
74 |         # Another shard has been opened.
75 |         filler.write_example(
76 |             values={"attribute_name": array_of_values[2]},
77 |             split=TRAIN_SPLIT,
78 |             custom_metadata=custom_metadata_1,
79 |         )
80 |         # Still the same shard.
81 |         filler.write_example(
82 |             values={"attribute_name": array_of_values[3]},
83 |             split=TRAIN_SPLIT,
84 |             custom_metadata=custom_metadata_1,
85 |         )
86 |         # Still the same shard.
87 |         filler.write_example(
88 |             values={"attribute_name": array_of_values[4]},
89 |             split=TRAIN_SPLIT,
90 |             custom_metadata=custom_metadata_1,
91 |         )
92 |         # Still the same shard.
93 |         filler.write_example(
94 |             values={"attribute_name": array_of_values[5]},
95 |             split=TRAIN_SPLIT,
96 |             custom_metadata=custom_metadata_1,
97 |         )
98 |         # Shard full, another opened.
99 |         filler.write_example(
100 |             values={"attribute_name": array_of_values[6]},
101 |             split=TRAIN_SPLIT,
102 |             custom_metadata=custom_metadata_1,
103 |         )
104 | 
105 |     # The object in memory and the saved metadata are the same.
106 |     assert dataset._dataset_info == Dataset(experiment_path)._dataset_info
107 | 
108 |     # There are three shards with the custom metadata.
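    # With examples_per_shard=4 the writes above produce: shard 0 with
    # examples 0-1 (closed early because the custom metadata changed),
    # shard 1 with examples 2-5 (full), and shard 2 with example 6.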
109 | shards: list[ShardInfo] = list(dataset.shard_info_iterator(TRAIN_SPLIT)) 110 | assert len(shards) == 3 111 | assert shards[0].custom_metadata == custom_metadata_0 112 | assert shards[1].custom_metadata == custom_metadata_1 113 | assert shards[2].custom_metadata == custom_metadata_1 114 | -------------------------------------------------------------------------------- /src/sedpack/io/flatbuffer/shardfile/Attribute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # automatically generated by the FlatBuffers compiler, do not modify 16 | 17 | # namespace: shardfile 18 | 19 | # pylint: skip-file 20 | 21 | import flatbuffers 22 | 23 | import numpy as np 24 | import numpy.typing as npt 25 | 26 | 27 | class Attribute(object): 28 | __slots__ = ['_tab'] 29 | 30 | @classmethod 31 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Attribute": 32 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 33 | x = Attribute() 34 | x.Init(buf, n + offset) 35 | return x 36 | 37 | # Attribute 38 | def Init(self, buf: bytes, pos: int) -> None: 39 | self._tab = flatbuffers.table.Table(buf, pos) 40 | 41 | # Attribute 42 | def AttributeBytes(self, j: int) -> bytes: 43 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 44 | if o != 0: 45 | a = self._tab.Vector(o) 46 | return self._tab.Get( # type: ignore[no-any-return] 47 | flatbuffers.number_types.Uint8Flags, 48 | a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) 49 | return bytes([]) 50 | 51 | # Attribute 52 | def AttributeBytesAsNumpy(self) -> npt.NDArray[np.uint8]: 53 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 54 | if o != 0: 55 | return self._tab.GetVectorAsNumpy( # type: ignore[no-any-return] 56 | flatbuffers.number_types.Uint8Flags, o) 57 | return np.array([], dtype=np.uint8) 58 | 59 | # Attribute 60 | def AttributeBytesLength(self) -> int: 61 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 62 | if o != 0: 63 | return self._tab.VectorLen(o) # type: ignore[no-any-return] 64 | return 0 65 | 66 | # Attribute 67 | def AttributeBytesIsNone(self) -> bool: 68 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type( 69 | self._tab.Offset(4)) 70 | return o == 0 71 | 72 | 73 | def AttributeStart( # type: ignore[no-any-unimported] 74 | builder: flatbuffers.builder.Builder) -> None: 75 | builder.StartObject(1) 76 | 77 | 78 | def Start( # type: ignore[no-any-unimported] 79 | builder: flatbuffers.builder.Builder) -> None: 80 | AttributeStart(builder) 81 | 82 | 83 | def AttributeAddAttributeBytes( # type: ignore[no-any-unimported] 84 | builder: flatbuffers.builder.Builder, attributeBytes: int) -> None: 85 | builder.PrependUOffsetTRelativeSlot( 86 | 0, 87 | flatbuffers.number_types.UOffsetTFlags.py_type(attributeBytes), 88 | 0, 89 | ) 90 | 91 | 92 | def AddAttributeBytes( # type: 
ignore[no-any-unimported] 93 | builder: flatbuffers.builder.Builder, attributeBytes: int) -> None: 94 | AttributeAddAttributeBytes(builder, attributeBytes) 95 | 96 | 97 | def AttributeStartAttributeBytesVector( # type: ignore[no-any-unimported] 98 | builder: flatbuffers.builder.Builder, numElems: int) -> int: 99 | return builder.StartVector(1, numElems, 1) # type: ignore[no-any-return] 100 | 101 | 102 | def StartAttributeBytesVector( # type: ignore[no-any-unimported] 103 | builder: flatbuffers.builder.Builder, numElems: int) -> int: 104 | return AttributeStartAttributeBytesVector(builder, numElems) 105 | 106 | 107 | def AttributeEnd( # type: ignore[no-any-unimported] 108 | builder: flatbuffers.builder.Builder) -> int: 109 | return builder.EndObject() # type: ignore[no-any-return] 110 | 111 | 112 | def End( # type: ignore[no-any-unimported] 113 | builder: flatbuffers.builder.Builder) -> int: 114 | return AttributeEnd(builder) 115 | -------------------------------------------------------------------------------- /base-tooling-requirements.txt: -------------------------------------------------------------------------------- 1 | build==1.2.1 --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 2 | click==8.2.0 --hash=sha256:6b303f0b2aa85f1cb4e5303078fadcbcd4e476f114fab9b5007005711839325c --hash=sha256:f5452aeddd9988eefa20f90f05ab66f17fce1ee2a36907fd30b05bbb5953814d 3 | importlib-metadata==8.7.0 --hash=sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000 --hash=sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd 4 | packaging==25.0 --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f 5 | pip-tools==7.4.0 --hash=sha256:a92a6ddfa86ff389fe6ace381d463bc436e2c705bd71d52117c25af5ce867bb7 --hash=sha256:b67432fd0759ed834c5367f9e0ce8c95441acecfec9c8e24b41aca166757adf0 6 | pyproject-hooks==1.2.0 --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 7 | tomli==2.2.1 --hash=sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6 --hash=sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd --hash=sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c --hash=sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b --hash=sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8 --hash=sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6 --hash=sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77 --hash=sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff --hash=sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea --hash=sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192 --hash=sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249 --hash=sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee --hash=sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4 --hash=sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98 --hash=sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8 --hash=sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4 
--hash=sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281 --hash=sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744 --hash=sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69 --hash=sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13 --hash=sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140 --hash=sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e --hash=sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e --hash=sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc --hash=sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff --hash=sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec --hash=sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2 --hash=sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222 --hash=sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106 --hash=sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272 --hash=sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a --hash=sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7 8 | typing-extensions==4.13.1 --hash=sha256:4b6cf02909eb5495cfbc3f6e8fd49217e6cc7944e145cdda8caa3734777f9e69 --hash=sha256:98795af00fb9640edec5b8e31fc647597b4691f099ad75f469a2616be1a76dff 9 | wheel==0.45.0 --hash=sha256:52f0baa5e6522155090a09c6bd95718cc46956d1b51d537ea5454249edb671c7 --hash=sha256:a57353941a3183b3d5365346b567a260a0602a0f8a635926a7dede41b94c674a 10 | zipp==3.23.0 --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 11 | pip==25.3 --hash=sha256:8d0538dbbd7babbd207f261ed969c65de439f6bc9e5dbd3b3b9a77f25d95f343 --hash=sha256:9655943313a94722b7774661c21049070f6bbb0a1516bf02f7c8d5d9201514cd 12 | setuptools==80.9.0 --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c 13 | -------------------------------------------------------------------------------- /docs/code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of 9 | experience, education, socio-economic status, nationality, personal appearance, 10 | race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or reject 41 | comments, commits, code, wiki edits, issues, and other contributions that are 42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any 43 | contributor for other behaviors that they deem inappropriate, threatening, 44 | offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when the Project 56 | Steward has a reasonable belief that an individual's behavior may have a 57 | negative impact on the project or its community. 58 | 59 | ## Conflict Resolution 60 | 61 | We do not believe that all conflict is bad; healthy debate and disagreement 62 | often yield positive results. However, it is never okay to be disrespectful or 63 | to engage in behavior that violates the project’s code of conduct. 64 | 65 | If you see someone violating the code of conduct, you are encouraged to address 66 | the behavior directly with those involved. Many issues can be resolved quickly 67 | and easily, and this gives people more control over the outcome of their 68 | dispute. If you are unable to resolve the matter for any reason, or if the 69 | behavior is threatening or harassing, report it. We are dedicated to providing 70 | an environment where participants feel welcome and safe. 71 | 72 | Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the 73 | Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to 74 | receive and address reported violations of the code of conduct. They will then 75 | work with a committee consisting of representatives from the Open Source 76 | Programs Office and the Google Open Source Strategy team. If for any reason you 77 | are uncomfortable reaching out to the Project Steward, please email 78 | [opensource@google.com](opensource@google.com). 
79 | 
80 | We will investigate every complaint, but you may not receive a direct response.
81 | We will use our discretion in determining when and how to follow up on reported
82 | incidents, which may range from not taking action to permanent expulsion from
83 | the project and project-sponsored spaces. We will notify the accused of the
84 | report and provide them an opportunity to discuss it before any action is taken.
85 | The identity of the reporter will be omitted from the details of the report
86 | supplied to the accused. In potentially harmful situations, such as ongoing
87 | harassment or threats to anyone's safety, we may take action without notice.
88 | 
89 | ## Attribution
90 | 
91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
92 | available at
93 | [https://www.contributor-covenant.org/version/1/4/code-of-conduct/](https://www.contributor-covenant.org/version/1/4/code-of-conduct/)
94 | 
--------------------------------------------------------------------------------
/tests/io/shard/test_shard_write_async.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from pathlib import Path
16 | import pytest
17 | 
18 | import numpy as np
19 | 
20 | from sedpack.io.metadata import Attribute, DatasetStructure
21 | from sedpack.io.types import ShardFileTypeT
22 | 
23 | from sedpack.io.flatbuffer import IterateShardFlatBuffer
24 | from sedpack.io.npz import IterateShardNP
25 | from sedpack.io.shard.get_shard_writer import get_shard_writer
26 | from sedpack.io.shard.iterate_shard_base import IterateShardBase
27 | pytest_plugins = ("pytest_asyncio",)
28 | 
29 | 
30 | async def shard_write_and_read(attributes: dict[str,
31 |                                                 np.ndarray], shard_file: Path,
32 |                                shard_file_type: ShardFileTypeT) -> None:
33 |     dataset_structure = DatasetStructure(
34 |         saved_data_description=[
35 |             Attribute(
36 |                 name=name,
37 |                 shape=value.shape[1:],
38 |                 dtype=str(value.dtype),
39 |             ) for name, value in attributes.items()
40 |         ],
41 |         shard_file_type=shard_file_type,
42 |         compression="",
43 |     )
44 | 
45 |     # Write data into the file.
46 |     writer = get_shard_writer(dataset_structure=dataset_structure,
47 |                               shard_file=shard_file)
48 |     one_value = next(iter(attributes.values()))  # One of the values.
49 |     for i in range(one_value.shape[0]):
50 |         writer.write(values={
51 |             name: value[i] for name, value in attributes.items()
52 |         })
53 |     writer.close()
54 | 
55 |     iterate_shard: IterateShardBase
56 |     match shard_file_type:
57 |         case "npz":
58 |             iterate_shard = IterateShardNP(dataset_structure=dataset_structure,
59 |                                            process_record=None)
60 |         case "fb":
61 |             iterate_shard = IterateShardFlatBuffer(
62 |                 dataset_structure=dataset_structure, process_record=None)
63 |         case _:
64 |             raise ValueError(f"Unknown {shard_file_type = }")
65 | 
66 |     # Read those back.
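    # `iterate_shard_async` is an async generator: it yields one example at a
    # time as a dict mapping attribute name to its value.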
67 |     seen: int = 0
68 |     i: int = 0
69 |     async for example in iterate_shard.iterate_shard_async(shard_file):
70 |         for name, value in attributes.items():
71 |             np.testing.assert_allclose(example[name], value[i])
72 |         seen += 1
73 |         i += 1  # manual enumerate
74 |     assert seen == one_value.shape[0]
75 | 
76 | 
77 | @pytest.mark.asyncio
78 | async def test_async_npz_with_int(tmp_path):
79 |     shard_file = tmp_path / "shard_file.npz"
80 |     attributes = {
81 |         "a": np.array([[13 + 512, 2, 3], [4, 5, 6]]),
82 |     }
83 |     await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
84 | 
85 | 
86 | @pytest.mark.asyncio
87 | async def test_async_npz_with_float(tmp_path):
88 |     shard_file = tmp_path / "shard_file.npz"
89 |     attributes = {
90 |         "a": np.random.uniform(size=(10, 15)),
91 |     }
92 |     await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
93 | 
94 | 
95 | @pytest.mark.asyncio
96 | async def test_async_npz_mixed(tmp_path):
97 |     shard_file = tmp_path / "shard_file.npz"
98 |     attributes = {
99 |         "a": np.random.uniform(size=(10, 15)),
100 |         "b": np.random.uniform(size=(10, 25)),
101 |         "c": np.random.randint(-5, 20, size=(10, 21)),
102 |     }
103 |     await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
104 | 
105 | 
106 | @pytest.mark.asyncio
107 | async def test_async_fb_with_int(tmp_path):
108 |     shard_file = tmp_path / "shard_file"
109 |     attributes = {
110 |         "a": np.array([[13 + 512, 2, 3], [4, 5, 6]]),
111 |     }
112 |     await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
113 | 
114 | 
115 | @pytest.mark.asyncio
116 | async def test_async_fb_with_float(tmp_path):
117 |     shard_file = tmp_path / "shard_file"
118 |     attributes = {
119 |         "a": np.random.uniform(size=(10, 15)),
120 |     }
121 |     await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
122 | 
123 | 
124 | @pytest.mark.asyncio
125 | async def test_async_fb_mixed(tmp_path):
126 |     shard_file = tmp_path / "shard_file"
127 |     attributes = {
128 |         "a": np.random.uniform(size=(10, 15)),
129 |         "b": np.random.uniform(size=(10, 25)),
130 |         "c": np.random.randint(-5, 20, size=(10, 21)),
131 |     }
132 |     await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
133 | 
--------------------------------------------------------------------------------
/website/src/content/docs/tutorials/sca/dataset.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: Converting the TinyAES Dataset into Sedpack
3 | description: Converting the TinyAES Dataset into Sedpack
4 | ---
5 | 
6 | The TinyAES dataset is a dataset of power measurements of a textbook [AES
7 | (Wikipedia)](https://de.wikipedia.org/wiki/Advanced_Encryption_Standard)
8 | implementation. It was introduced for the DefCon 27 demo: [A Hacker Guide To
9 | Deep Learning Based Side Channel
10 | Attacks](https://elie.net/talk/a-hackerguide-to-deep-learning-based-side-channel-attacks).
11 | In this tutorial we show how to convert the original data into the sedpack
12 | format.
13 | 
14 | ## Original Dataset
15 | 
16 | The dataset was captured from STM32F4 chips using the ChipWhisperer [CW308 UFO
17 | board](https://www.newae.com/chipwhisperer). The capture was asynchronous,
18 | meaning that the oscilloscope clock and the target chip clock were not
19 | synchronized (the oscilloscope was oversampling to cope with this).
20 | The oscilloscope used was PicoScope® 6404D.
21 | 
22 | One can download all zipped files either by clicking the following link
23 | [datasets.zip
24 | (8.2GB)](https://storage.googleapis.com/scaaml-public/scaaml_intro/datasets.zip)
25 | or:
26 | 
27 | ```bash
28 | wget https://storage.googleapis.com/scaaml-public/scaaml_intro/datasets.zip
29 | sha256sum datasets.zip  # 4bf2c6defb79b40b30f01f488e83762396b56daad14a694f64916be2b665b2f8
30 | unzip datasets.zip
31 | ```
32 | 
33 | The original files were saved in the
34 | [NumPy](https://numpy.org/doc/1.26/reference/generated/numpy.savez.html)
35 | format, each file having a constant key and balanced plaintexts (see [SCAAML
36 | Attack Point
37 | Iterators](https://google.github.io/scaaml/guides/capture/attack_point_iterators/)
38 | for an explanation). This is completely fine except that:
39 | 
40 | - There is no standard format -- anybody who wants to experiment with this
41 |   dataset has to discover the data format and write custom data loaders. Not
42 |   complicated, but tedious.
43 | - This approach does not scale well to huge datasets.
44 | 
45 | ## Conversion to Sedpack
46 | 
47 | The script
48 | [docs/tutorials/sca/tiny_aes.py](https://github.com/google/sedpack/blob/main/docs/tutorials/sca/tiny_aes.py)
49 | contains a function `convert_to_sedpack` which takes the directory with the
50 | original NumPy files and converts it into a sedpack dataset.
51 | 
52 | We save the following attributes:
53 | 
54 | - `trace1`: float16, length 80_000
55 | - `key`: uint8, length 16
56 | - `plaintext`: uint8, length 16
57 | - `ciphertext`: uint8, length 16
58 | - `sub_bytes_out`: uint8, length 16
59 | - `sub_bytes_in`: uint8, length 16
60 | 
61 | Most of our scenarios deal with profiled attacks, where during the profiling
62 | phase we know the secret (`key`) and during the attack phase we try to recover
63 | it. For convenience we also save the `ciphertext` (the result of encrypting
64 | the given `plaintext` with the `key`) and the first round of AES S-BOX inputs
65 | and outputs; we could omit these. In a profiled setting (we train on one
66 | device and attack another; these phases are called the profiling phase and
67 | the attack phase) `ciphertext`, `sub_bytes_in`, and `sub_bytes_out` could be
68 | considered redundant information. In the non-profiled setting (a single
69 | device where we do not know the `key`) we would have access only to
70 | `plaintext` and `ciphertext`.
71 | 
72 | Naturally, if we wanted to save more attack points (e.g., the last round of
73 | AES states or all inputs of the S-BOX), we could. Alternatively, we can
74 | compute those on the fly if and when we need them.
75 | 
76 | A limitation of the TinyAES dataset is that all data was collected on a
77 | single device (both the original `train` and `test` splits). This is one of
78 | the reasons this dataset is suitable for demonstration purposes only when
79 | considering profiled attacks. In a realistic scenario one would capture
80 | `train` and `test` on the profiling chip and `holdout` on a different
81 | physical device (of the same type). We deal with the lack of a `holdout`
82 | split by splitting the original `test` into `test` and `holdout`.
83 | 
84 | ### Iteration
85 | 
86 | Let us check the data we just converted.
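Before printing anything we can run a quick sanity check that the attributes
round-tripped with the shapes declared above. The following is a small sketch
(it assumes the conversion wrote the dataset into `tiny_aes_sedpack`, as the
script does):

```python
from sedpack.io import Dataset

dataset = Dataset("tiny_aes_sedpack")
example = next(iter(dataset.as_numpy_iterator(split="train")))

# Shapes should match the attribute list above.
assert example["trace1"].shape == (80_000,)
assert example["key"].shape == (16,)
print({name: (value.shape, value.dtype) for name, value in example.items()})
```

With the structure confirmed we can iterate over the examples and print each
`key`: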
87 | 
88 | ```python
89 | from sedpack.io import Dataset
90 | 
91 | dataset = Dataset("tiny_aes_sedpack")
92 | for example in dataset.as_numpy_iterator(split="train",
93 |                                          shuffle=0,
94 |                                          repeat=False):
95 |     print(example["key"])
96 | ```
97 | 
98 | We should see groups of 256 examples with exactly the same `key` value. The
99 | same holds for the `test` and `holdout` splits.
100 | 
101 | And we plot a single trace just to be sure everything works as expected.
102 | 
103 | ```python
104 | import matplotlib.pyplot as plt
105 | 
106 | # Plot a single trace.
107 | plt.plot(next(iter(dataset.as_numpy_iterator(split="train")))["trace1"])
108 | plt.savefig("tiny_aes_trace.png")
109 | ```
110 | 
111 | ![tiny_aes_trace.png](tiny_aes_trace.png)
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: pytest
2 | permissions:
3 |   contents: read
4 |   pull-requests: write
5 | on:
6 |   pull_request:
7 |     types: [opened, synchronize, reopened]
8 |     paths:
9 |       - '**/*.py'
10 |       - '**/*.rs'
11 |       - 'pytest.ini'
12 |   schedule:
13 |     - cron: 0 5 * * 1  # Every Monday at 5:00 UTC
14 |   merge_group:  # Needed for required workflows
15 |   # Run after a review has been submitted (this is a required workflow which
16 |   # might not be triggered when no code changes -- trigger before going to
17 |   # merge queue).
18 |   pull_request_review:
19 |     types: [submitted]
20 | 
21 | jobs:
22 |   unittesting:
23 |     runs-on: ${{ matrix.platform.runner }}
24 |     strategy:
25 |       matrix:
26 |         # ubuntu-20.04-arm was not stable enough when testing
27 |         platform:
28 |           - runner: ubuntu-latest  # x64
29 |           - runner: windows-latest  # x64
30 |           - runner: macos-14  # arm64
31 |           - runner: macos-15-intel  # Intel
32 |           - runner: macos-latest  # arm64
33 |     if: github.event_name != 'schedule'
34 |     steps:
35 |       - uses: actions/checkout@v6
36 |       - name: Set up Python 3.10
37 |         uses: actions/setup-python@v6
38 |         with:
39 |           python-version: '3.10'
40 |           cache: 'pip'
41 |       - name: Get pip cache directory
42 |         id: pip-cache
43 |         shell: bash
44 |         run: |
45 |           echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
46 |       - name: Use cached venv or create it
47 |         uses: actions/cache/restore@v5
48 |         id: cache
49 |         with:
50 |           path: ${{ steps.pip-cache.outputs.dir }}
51 |           key: ${{ matrix.platform.runner }}-pip
52 |       # Build a virtualenv, but only if it doesn't already exist
53 |       - name: Populate pip cache
54 |         # requirements.txt is not reliable since across different platforms and
55 |         # their versions the pip package versions might vary. We regenerate it
56 |         # from pyproject.toml every time pyproject.toml or requirements.txt
57 |         # changes. The pinned versions in requirements.txt are tested by
58 |         # coverage since that runs on ubuntu, which is also used to produce
59 |         # the main requirements.txt file.
60 | run: | 61 | pip install pip==25.2 # TODO(remove the pinning) pip-tools issue 2252 62 | pip install pip-tools 63 | pip-compile --generate-hashes --extra dev pyproject.toml > dev_requirements.txt 64 | pip install -r dev_requirements.txt 65 | if: steps.cache.outputs.cache-hit != 'true' 66 | - name: Save cache 67 | id: cache-save 68 | uses: actions/cache/save@v5 69 | with: 70 | path: ${{ steps.pip-cache.outputs.dir }} 71 | key: ${{ steps.cache.outputs.cache-primary-key }} 72 | if: steps.cache.outputs.cache-hit != 'true' 73 | - name: Install sedpack locally 74 | run: pip install --editable ".[dev]" 75 | - name: Running unit tests 76 | run: python -m pytest 77 | 78 | coverage: 79 | runs-on: ubuntu-22.04 80 | steps: 81 | - uses: actions/checkout@v6 82 | - name: Set up Python 3.10 83 | uses: actions/setup-python@v6 84 | with: 85 | python-version: '3.10' 86 | cache: 'pip' 87 | - name: Get pip cache directory 88 | id: pip-cache 89 | shell: bash 90 | run: | 91 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 92 | - name: Use cached venv or create it 93 | uses: actions/cache/restore@v5 94 | id: cache 95 | with: 96 | path: ${{ steps.pip-cache.outputs.dir }} 97 | # The cache key depends on requirements.txt 98 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 99 | # Build a virtualenv, but only if it doesn't already exist 100 | - name: Populate pip cache 101 | run: | 102 | python -m pip install --require-hashes --no-deps -r requirements.txt 103 | if: steps.cache.outputs.cache-hit != 'true' 104 | - name: Save cache 105 | id: cache-save 106 | uses: actions/cache/save@v5 107 | with: 108 | path: ${{ steps.pip-cache.outputs.dir }} 109 | key: ${{ steps.cache.outputs.cache-primary-key }} 110 | if: steps.cache.outputs.cache-hit != 'true' 111 | - name: Installing test requirements and sedpack 112 | # Start by "installing" sedpack to be sure all dependencies are listed 113 | run: | 114 | pip install --editable ".[dev]" 115 | echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV 116 | - name: Install workflow dependencies 117 | run: pip install --upgrade pytest coverage 118 | - name: Running unit tests with coverage 119 | env: 120 | DISABLE_AUTOGRAPH: 1 121 | # TODO remove the -i (ignore errors) 122 | run: | 123 | coverage run -m pytest 124 | coverage xml -i 125 | - name: Upload results 126 | uses: coverallsapp/github-action@648a8eb78e6d50909eff900e4ec85cab4524a45b # v2 127 | with: 128 | file: coverage.xml 129 | 130 | -------------------------------------------------------------------------------- /docs/tutorials/quick_start/mnist_save.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Download the MNIST dataset and save it in dataset-lib format. 
For a tutorial
15 | with explanations see: https://google.github.io/sedpack/tutorials/mnist
16 | 
17 | Example use:
18 |     python mnist_save.py -d "~/Datasets/mnist_dataset/"
19 |     python mnist_read_keras.py -d "~/Datasets/mnist_dataset/"
20 | """
21 | 
22 | import argparse
23 | import random
24 | from typing import get_args
25 | 
26 | from tensorflow import keras
27 | from tqdm import tqdm
28 | 
29 | from sedpack.io import Dataset, Metadata, DatasetStructure, Attribute
30 | from sedpack.io.types import SplitT
31 | from sedpack.io.types import CompressionT, ShardFileTypeT
32 | 
33 | 
34 | def main() -> None:
35 |     """Convert the MNIST dataset into sedpack format.
36 |     """
37 |     parser = argparse.ArgumentParser(
38 |         description="Convert MNIST dataset into dataset-lib format")
39 |     parser.add_argument("--dataset_directory",
40 |                         "-d",
41 |                         help="Where to save the dataset",
42 |                         required=True)
43 |     parser.add_argument("--compression",
44 |                         "-c",
45 |                         help="Which compression algorithm to use",
46 |                         default="GZIP",
47 |                         choices=get_args(CompressionT))
48 |     parser.add_argument("--shard_file_type",
49 |                         "-f",
50 |                         help="Which shard file type to use",
51 |                         default="tfrec",
52 |                         choices=get_args(ShardFileTypeT))
53 |     args = parser.parse_args()
54 | 
55 |     # General info about the dataset
56 |     metadata = Metadata(
57 |         description="MNIST dataset in the sedpack format",
58 |         dataset_license="""
59 |         Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, which is
60 |         a derivative work from original NIST datasets. MNIST dataset is made
61 |         available under the terms of the Creative Commons Attribution-Share Alike
62 |         3.0 license.
63 |         """,
64 |         custom_metadata={
65 |             "list of authors": ["Yann LeCun", "Corinna Cortes"],
66 |         },
67 |     )
68 | 
69 |     # Types of attributes stored
70 |     dataset_structure = DatasetStructure(
71 |         saved_data_description=[
72 |             Attribute(
73 |                 name="input",
74 |                 shape=(28, 28),
75 |                 dtype="float32",
76 |             ),
77 |             Attribute(
78 |                 name="digit",
79 |                 shape=(),
80 |                 dtype="uint8",
81 |             ),
82 |         ],
83 |         compression=args.compression,
84 |         shard_file_type=args.shard_file_type,
85 |     )
86 | 
87 |     # Create a new dataset
88 |     dataset = Dataset.create(
89 |         path=args.dataset_directory,  # All files are stored here
90 |         metadata=metadata,
91 |         dataset_structure=dataset_structure,
92 |     )
93 | 
94 |     # Fill in examples
95 |     (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
96 |     x_train = x_train.astype("float32") / 256
97 |     x_test = x_test.astype("float32") / 256
98 | 
99 |     # DatasetFiller makes sure that all shard files are written properly
100 |     # when exiting the context.
101 |     with dataset.filler() as dataset_filler:
102 |         # The original MNIST test data becomes our holdout split.
103 |         for i in tqdm(range(len(x_test)), desc="holdout"):
104 |             dataset_filler.write_example(
105 |                 values={
106 |                     "input": x_test[i],
107 |                     "digit": y_test[i],
108 |                 },
109 |                 split="holdout",
110 |             )
111 | 
112 |         # Randomly assign 10% to validation and the rest to training.
113 |         assert len(x_train) == len(y_train)
114 |         train_indices: list[int] = list(range(len(x_train)))
115 |         random.shuffle(train_indices)
116 |         validation_split_position: int = int(len(x_train) * 0.1)
117 |         for index_position, index in enumerate(
118 |                 tqdm(train_indices, desc="train and val")):
119 | 
120 |             # Assign to either train (90%) or test (the validation split).
121 |             split: SplitT = "train"
122 |             if index_position < validation_split_position:
123 |                 split = "test"
124 | 
125 |             # Write the example.
126 | dataset_filler.write_example( 127 | values={ 128 | "input": x_train[index], 129 | "digit": y_train[index], 130 | }, 131 | split=split, 132 | ) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /tests/io/test_continue_writing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from pathlib import Path 17 | import random 18 | from typing import Union 19 | 20 | import numpy as np 21 | 22 | import sedpack 23 | from sedpack.io import Dataset 24 | from sedpack.io import Metadata 25 | 26 | 27 | def get_dataset(tmpdir: Union[str, Path]) -> Dataset: 28 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment" 29 | 30 | # Create a dataset 31 | 32 | dataset_metadata = Metadata(description="Test of the lib") 33 | 34 | example_attributes = [ 35 | sedpack.io.metadata.Attribute( 36 | name="attribute_name", 37 | dtype="float32", 38 | shape=(138,), 39 | ), 40 | ] 41 | 42 | dataset_structure = sedpack.io.metadata.DatasetStructure( 43 | saved_data_description=example_attributes,) 44 | 45 | dataset = Dataset.create( 46 | path=tiny_experiment_path, 47 | metadata=dataset_metadata, 48 | dataset_structure=dataset_structure, 49 | ) 50 | 51 | return dataset 52 | 53 | 54 | def fill(dataset, split, data): 55 | # Fill data in the dataset 56 | with dataset.filler() as filler: 57 | for attribute_value in data: 58 | filler.write_example( 59 | values={"attribute_name": attribute_value}, 60 | split=split, 61 | ) 62 | 63 | # Check the data is correct 64 | # Reopen the dataset 65 | dataset = Dataset(dataset.path) 66 | dataset.check() 67 | 68 | 69 | def check_presence(dataset, split, data): 70 | for i, example in enumerate( 71 | dataset.as_numpy_iterator( 72 | split=split, 73 | shuffle=0, 74 | repeat=False, 75 | )): 76 | assert np.allclose(example["attribute_name"], data[i:i + 1]) 77 | 78 | # We tested everything 79 | assert i + 1 == data.shape[0], "Not all examples have been iterated" 80 | 81 | 82 | def test_continue_writing_another_split(tmpdir: Union[str, Path]) -> None: 83 | """Check that we can write more examples into empty / single split. This 84 | would uncover the bug addressed by merging updates info. 
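    Each `fill` below reopens the dataset and runs `Dataset.check`, and
    `check_presence` verifies iteration both on the live object and on a
    freshly loaded `Dataset(dataset.path)`.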
85 | """ 86 | data_train = np.random.random((1024, 138)).astype(dtype=np.float32) 87 | filled_train: int = 0 88 | data_test = np.random.random((1024, 138)).astype(dtype=np.float32) 89 | filled_test: int = 0 90 | 91 | dataset = get_dataset(tmpdir) 92 | 93 | fill_now: int = 20 94 | 95 | for _ in range(4): 96 | fill_now = random.randint(10, 50) 97 | fill( 98 | dataset=dataset, 99 | split="train", 100 | data=data_train[filled_train:filled_train + fill_now], 101 | ) 102 | filled_train += fill_now 103 | 104 | fill_now = random.randint(10, 50) 105 | fill( 106 | dataset=dataset, 107 | split="test", 108 | data=data_test[filled_test:filled_test + fill_now], 109 | ) 110 | filled_test += fill_now 111 | 112 | # Both splits are present after writing (not just directly after 113 | # writing into one). 114 | check_presence(dataset, "train", data_train[:filled_train]) 115 | check_presence(Dataset(dataset.path), "train", 116 | data_train[:filled_train]) 117 | assert dataset._dataset_info.splits[ 118 | "train"].number_of_examples == filled_train 119 | check_presence(dataset, "test", data_test[:filled_test]) 120 | check_presence(Dataset(dataset.path), "test", data_test[:filled_test]) 121 | assert dataset._dataset_info.splits[ 122 | "test"].number_of_examples == filled_test 123 | 124 | 125 | def test_local_root_path(tmpdir: Union[str, Path]) -> None: 126 | """Check that relative path checks work even when dataset root path is 127 | local. For this we need to write multiple times in the dataset. 128 | """ 129 | # Change the working directory to be in the /tmp/pytest-of-user/ 130 | os.chdir(tmpdir) 131 | 132 | data_train = np.random.random((1024, 138)).astype(dtype=np.float32) 133 | filled_train: int = 0 134 | data_test = np.random.random((1024, 138)).astype(dtype=np.float32) 135 | filled_test: int = 0 136 | 137 | dataset = get_dataset("my_dataset") 138 | 139 | fill_now: int = 20 140 | 141 | for _ in range(4): 142 | fill_now = random.randint(10, 50) 143 | fill( 144 | dataset=dataset, 145 | split="train", 146 | data=data_train[filled_train:filled_train + fill_now], 147 | ) 148 | filled_train += fill_now 149 | 150 | fill_now = random.randint(10, 50) 151 | fill( 152 | dataset=dataset, 153 | split="test", 154 | data=data_test[filled_test:filled_test + fill_now], 155 | ) 156 | filled_test += fill_now 157 | -------------------------------------------------------------------------------- /src/sedpack/io/merge_shard_infos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Merge new shard lists into the dataset. 
15 | """ 16 | 17 | from collections import defaultdict 18 | from pathlib import Path 19 | 20 | from sedpack.io.shard_file_metadata import ShardsList, ShardListInfo 21 | from sedpack.io.types import HashChecksumT 22 | 23 | 24 | def merge_shard_infos(updates: list[ShardListInfo], dataset_root: Path, 25 | common: int, hashes: tuple[HashChecksumT, 26 | ...]) -> ShardListInfo: 27 | """Merge a list of new `ShardListInfo`s into the dataset. 28 | 29 | Args: 30 | 31 | updates (list[ShardListInfo]): New shards lists information to merge. All 32 | of these belonging to the same split. 33 | 34 | dataset_root (Path): Where the dataset is saved. 35 | 36 | common (int): A positive integer indicating how many directories deep are 37 | shared between all `ShardListInfo`s (relative from the `dataset_root`). 38 | When calling this function all `updates` have to be in the same split, 39 | thus one should set `common=1`. It is not guaranteed to update 40 | `shards_list.json` files all the way to the split when `common>1`. 41 | 42 | hashes (tuple[HashChecksumT, ...]): A tuple of hash checksum algorithms. 43 | """ 44 | assert updates, "Nothing to update." 45 | 46 | # Check that all common prefixes are the same. 47 | for update in updates: 48 | if update.shard_list_info_file.file_path.parts[:common] != updates[ 49 | 0].shard_list_info_file.file_path.parts[:common]: 50 | raise ValueError( 51 | f"Not all relative paths are the same " 52 | f"{update.shard_list_info_file.file_path.parts[:common]} vs " 53 | f"{updates[0].shard_list_info_file.file_path.parts[:common]}") 54 | 55 | # The current level ShardsList (if it is in the updates). 56 | root_shard_list: ShardsList = ShardsList.load_or_create( 57 | dataset_root_path=dataset_root, 58 | relative_path_self=Path().joinpath( 59 | *updates[0].shard_list_info_file.file_path.parts[:common]) / 60 | "shards_list.json", 61 | ) 62 | 63 | # Divide the updates to current level shard lists and the deeper ones. 64 | current_level: list[ShardListInfo] = [ 65 | update for update in updates 66 | if len(update.shard_list_info_file.file_path.parts) == common + 1 67 | ] 68 | deeper_updates: list[ShardListInfo] = [ 69 | update for update in updates 70 | if len(update.shard_list_info_file.file_path.parts) > common + 1 71 | ] 72 | # Check correctness of this implementation (O(1) check just to make sure we 73 | # do not forget anything). 74 | assert len(current_level) + len(deeper_updates) == len(updates) 75 | # Since the ShardsList is saved in a file named shards_list.json there can 76 | # be at most one update in this depth. We can ignore it since it has been 77 | # loaded into root_shard_list. 78 | assert len(current_level) <= 1 79 | 80 | # Move children of root_shard_list into deeper_updates to let recursion 81 | # merge everything. The updates are listed after the already present. 82 | for child in root_shard_list.children_shard_lists: 83 | root_shard_list.number_of_examples -= child.number_of_examples 84 | deeper_updates = root_shard_list.children_shard_lists + deeper_updates 85 | root_shard_list.children_shard_lists = [] 86 | 87 | # Recursively update children with one longer common prefix. 88 | # Sort the infos by common directory. 89 | recursively_update: defaultdict[str, 90 | list[ShardListInfo]] = defaultdict(list) 91 | for update in deeper_updates: 92 | # The path is at least `split / $DIRECTORY / shards_list.json` or 93 | # longer. 
94 | current_path: Path = update.shard_list_info_file.file_path 95 | directory = str(current_path.parts[common]) 96 | recursively_update[directory].append(update) 97 | 98 | # Recursively update. 99 | merged: dict[str, ShardListInfo] = { # Merge recursively. 100 | directory: 101 | merge_shard_infos(updates=recursive_updates, 102 | dataset_root=dataset_root, 103 | common=common + 1, 104 | hashes=hashes) 105 | for directory, recursive_updates in recursively_update.items() 106 | } 107 | 108 | # Merge the recursive into root_shard_list. 109 | for child in merged.values(): 110 | root_shard_list.number_of_examples += child.number_of_examples 111 | root_shard_list.children_shard_lists.append(child) 112 | 113 | # Write the root_shard_list and return its ShardListInfo. 114 | return root_shard_list.write_config( 115 | dataset_root_path=dataset_root, 116 | hashes=hashes, 117 | ) 118 | -------------------------------------------------------------------------------- /tests/io/test_end2end_wrong_type.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | import pytest 17 | from typing import Union 18 | 19 | import numpy as np 20 | import numpy.typing as npt 21 | 22 | import sedpack 23 | from sedpack.io import Dataset 24 | from sedpack.io import Metadata 25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT 26 | 27 | 28 | def end2end(tmpdir: Union[str, Path], input_dtype: npt.DTypeLike, 29 | saved_dtype: npt.DTypeLike, method: str, 30 | shard_file_type: ShardFileTypeT, compression: CompressionT) -> None: 31 | array_of_values = np.random.random((10, 138)) * 256 32 | array_of_values = np.array(array_of_values, dtype=input_dtype) 33 | #print(f">>> {array_of_values.dtype = }") 34 | #print(f">>> {array_of_values.shape = }") 35 | #print(f">>> {array_of_values[0, :10] = }") 36 | #print(f">>> {array_of_values.tobytes() = }") 37 | 38 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment" 39 | 40 | # Create a dataset 41 | 42 | dataset_metadata = Metadata(description="Test of the lib") 43 | 44 | example_attributes = [ 45 | sedpack.io.metadata.Attribute( 46 | name="filling_before", 47 | dtype=str(array_of_values.dtype), 48 | shape=array_of_values.shape, 49 | ), 50 | sedpack.io.metadata.Attribute( 51 | name="attribute_name", 52 | dtype=str(saved_dtype), 53 | shape=array_of_values[0].shape, 54 | ), 55 | sedpack.io.metadata.Attribute( 56 | name="filling_after", 57 | dtype=str(array_of_values.dtype), 58 | shape=array_of_values.shape, 59 | ), 60 | ] 61 | 62 | dataset_structure = sedpack.io.metadata.DatasetStructure( 63 | saved_data_description=example_attributes, 64 | compression=compression, 65 | examples_per_shard=256, 66 | shard_file_type=shard_file_type, 67 | ) 68 | 69 | dataset = Dataset.create( 70 | path=tiny_experiment_path, 71 | metadata=dataset_metadata, 72 | dataset_structure=dataset_structure, 
73 | ) 74 | 75 | # Fill data in the dataset 76 | 77 | with dataset.filler() as filler: 78 | for attribute_value in array_of_values: 79 | filler.write_example( 80 | values={ 81 | "filling_before": array_of_values, 82 | "attribute_name": attribute_value, 83 | "filling_after": array_of_values, 84 | }, 85 | split=TRAIN_SPLIT, 86 | ) 87 | 88 | # Check the data is correct 89 | # Reopen the dataset 90 | dataset = Dataset(tiny_experiment_path) 91 | dataset.check() 92 | 93 | match method: 94 | case "as_tfdataset": 95 | for i, example in enumerate( 96 | dataset.as_tfdataset( 97 | split=TRAIN_SPLIT, 98 | shuffle=0, 99 | repeat=False, 100 | batch_size=1, 101 | )): 102 | assert np.allclose(example["attribute_name"], 103 | array_of_values[i:i + 1]) 104 | case "as_numpy_iterator": 105 | for i, example in enumerate( 106 | dataset.as_numpy_iterator( 107 | split=TRAIN_SPLIT, 108 | shuffle=0, 109 | repeat=False, 110 | )): 111 | assert np.allclose(example["attribute_name"], 112 | array_of_values[i]) 113 | case "as_numpy_iterator_concurrent": 114 | for i, example in enumerate( 115 | dataset.as_numpy_iterator_concurrent( 116 | split=TRAIN_SPLIT, 117 | shuffle=0, 118 | repeat=False, 119 | )): 120 | assert np.allclose(example["attribute_name"], 121 | array_of_values[i]) 122 | 123 | # We tested everything 124 | assert i + 1 == array_of_values.shape[ 125 | 0], "Not all examples have been iterated" 126 | 127 | 128 | def test_end2end_wrong_value_type(tmpdir: Union[str, Path]) -> None: 129 | end2end(tmpdir=tmpdir, 130 | input_dtype="uint8", 131 | saved_dtype="int32", 132 | method="as_numpy_iterator", 133 | shard_file_type="fb", 134 | compression="LZ4") 135 | 136 | 137 | def test_end2end_wrong_value_type_no_cast(tmpdir: Union[str, Path]) -> None: 138 | with pytest.raises(ValueError) as e: 139 | end2end(tmpdir=tmpdir, 140 | input_dtype="float32", 141 | saved_dtype="uint8", 142 | method="as_numpy_iterator", 143 | shard_file_type="fb", 144 | compression="LZ4") 145 | -------------------------------------------------------------------------------- /tests/io/test_as_tfdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Test all shard file types with as_tfdataset. 
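
Each (shard_file_type, compression, dtype, process_record) combination is
exercised end to end: the data is written through `Dataset.filler`, the
dataset is reopened and checked, and the `as_tfdataset` output is compared
against the original values (after applying `process_record` when given).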
15 | """ 16 | 17 | import itertools 18 | from pathlib import Path 19 | from typing import Callable, Union 20 | 21 | import pytest 22 | import numpy as np 23 | import numpy.typing as npt 24 | 25 | import sedpack 26 | from sedpack.io import Dataset, Metadata 27 | from sedpack.io.shard.iterate_shard_base import T 28 | from sedpack.io.shard.shard_writer_flatbuffer import ShardWriterFlatBuffer 29 | from sedpack.io.shard.shard_writer_np import ShardWriterNP 30 | from sedpack.io.shard.shard_writer_tfrec import ShardWriterTFRec 31 | from sedpack.io.types import ( 32 | ExampleT, 33 | TRAIN_SPLIT, 34 | CompressionT, 35 | ShardFileTypeT, 36 | ) 37 | from sedpack.io.utils import is_module_present 38 | 39 | 40 | def end2end( 41 | tmpdir: Union[str, Path], 42 | dtype: npt.DTypeLike, 43 | shard_file_type: ShardFileTypeT, 44 | compression: CompressionT, 45 | process_record: Callable[[ExampleT], T] | None, 46 | ) -> None: 47 | array_of_values = np.random.random((1024, 138)) 48 | array_of_values = array_of_values.astype(dtype) 49 | 50 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment" 51 | 52 | # Create a dataset 53 | 54 | dataset_metadata = Metadata(description="Test of the lib") 55 | 56 | example_attributes = [ 57 | sedpack.io.metadata.Attribute( 58 | name="attribute_name", 59 | dtype=str(dtype), 60 | shape=array_of_values[0].shape, 61 | ), 62 | ] 63 | 64 | dataset_structure = sedpack.io.metadata.DatasetStructure( 65 | saved_data_description=example_attributes, 66 | compression=compression, 67 | examples_per_shard=256, 68 | shard_file_type=shard_file_type, 69 | ) 70 | 71 | dataset = Dataset.create( 72 | path=tiny_experiment_path, 73 | metadata=dataset_metadata, 74 | dataset_structure=dataset_structure, 75 | ) 76 | 77 | # Fill data in the dataset 78 | 79 | with dataset.filler() as filler: 80 | for attribute_value in array_of_values: 81 | filler.write_example( 82 | values={"attribute_name": attribute_value}, 83 | split=TRAIN_SPLIT, 84 | ) 85 | 86 | # Check the data is correct 87 | # Reopen the dataset 88 | dataset = Dataset(tiny_experiment_path) 89 | dataset.check() 90 | 91 | for i, example in enumerate( 92 | dataset.as_tfdataset( 93 | split=TRAIN_SPLIT, 94 | shuffle=0, 95 | repeat=False, 96 | batch_size=1, 97 | process_record=process_record, 98 | )): 99 | if process_record: 100 | assert np.allclose( 101 | example["attribute_name"], 102 | process_record({"attribute_name": array_of_values[i:i + 1] 103 | })["attribute_name"], 104 | ) 105 | else: 106 | assert np.allclose(example["attribute_name"], 107 | array_of_values[i:i + 1]) 108 | 109 | # We tested everything 110 | assert i + 1 == array_of_values.shape[ 111 | 0], "Not all examples have been iterated" 112 | 113 | 114 | @pytest.mark.skipif( 115 | not is_module_present("tensorflow"), 116 | reason="TensorFlow is optional, skip test if not present.", 117 | ) 118 | @pytest.mark.parametrize( 119 | "shard_file_type,compression,dtype,process_record", 120 | itertools.chain( 121 | itertools.product( 122 | ["tfrec"], 123 | ShardWriterTFRec.supported_compressions(), 124 | ["float16", "float32"], 125 | [ 126 | None, 127 | lambda d: { 128 | k: v + 1 for k, v in d.items() 129 | }, 130 | ], 131 | ), 132 | itertools.product( 133 | ["npz"], 134 | ShardWriterNP.supported_compressions(), 135 | ["float16", "float32"], 136 | [ 137 | None, 138 | lambda d: { 139 | k: v + 1 for k, v in d.items() 140 | }, 141 | ], 142 | ), 143 | itertools.product( 144 | ["fb"], 145 | ShardWriterFlatBuffer.supported_compressions(), 146 | ["float16", "float32"], 147 | [ 148 | None, 149 | 
lambda d: { 150 | k: v + 1 for k, v in d.items() 151 | }, 152 | ], 153 | ), 154 | ), 155 | ) 156 | def test_end2end_as_tfdataset( 157 | shard_file_type: str, 158 | compression: str, 159 | dtype: str, 160 | process_record: Callable[[ExampleT], T] | None, 161 | tmp_path: Union[str, Path], 162 | ) -> None: 163 | end2end( 164 | tmpdir=tmp_path, 165 | dtype=dtype, 166 | shard_file_type=shard_file_type, 167 | compression=compression, 168 | process_record=process_record, 169 | ) 170 | -------------------------------------------------------------------------------- /docs/tutorials/quick_start/mnist_read_keras.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Read MNIST data and feed it to a neural network. For a tutorial with 15 | explanations see: https://google.github.io/sedpack/tutorials/mnist 16 | 17 | Example use: 18 | python mnist_save.py -d "~/Datasets/my_new_dataset/" 19 | python mnist_read_keras.py -d "~/Datasets/my_new_dataset/" 20 | """ 21 | import argparse 22 | from typing import Any, Tuple 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | from tensorflow import keras 27 | from tensorflow.keras import layers 28 | 29 | from sedpack.io import Dataset 30 | from sedpack.io.types import ExampleT, TFModelT 31 | 32 | 33 | def get_model() -> TFModelT: 34 | """Return a CNN model. 35 | """ 36 | input_shape = (28, 28) 37 | num_classes = 10 38 | 39 | input_data = keras.Input(shape=input_shape, name="input") 40 | 41 | x = input_data 42 | x = layers.Reshape((*input_shape, 1))(x) 43 | x = layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(x) 44 | x = layers.MaxPooling2D(pool_size=(2, 2))(x) 45 | x = layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x) 46 | x = layers.MaxPooling2D(pool_size=(2, 2))(x) 47 | x = layers.Flatten()(x) 48 | x = layers.Dropout(0.5)(x) 49 | x = layers.Dense(num_classes, activation="softmax", name="digit")(x) 50 | 51 | model: TFModelT = keras.Model(inputs=input_data, outputs=x) 52 | 53 | model.summary() 54 | model.compile(loss="categorical_crossentropy", 55 | optimizer="adam", 56 | metrics=["accuracy"]) 57 | return model 58 | 59 | 60 | def main() -> None: 61 | """Train a neural network on the MNIST dataset saved in the sedpack 62 | format. 
63 | """ 64 | parser = argparse.ArgumentParser( 65 | description= 66 | "Read MNIST in dataset-lib format and train a small neural network.") 67 | parser.add_argument("--dataset_directory", 68 | "-d", 69 | help="Where to load the dataset", 70 | required=True) 71 | parser.add_argument("--ascii_evaluations", 72 | "-e", 73 | help="How many images to print and evaluate", 74 | type=int, 75 | default=10) 76 | args = parser.parse_args() 77 | 78 | # Load train and test and train 79 | model = get_model() 80 | 81 | dataset = Dataset(args.dataset_directory) # Load the dataset 82 | 83 | # ExampleT: TypeAlias of dict[str, sedpack.io.types.AttributeValueT] 84 | def process_record(rec: ExampleT) -> Tuple[Any, Any]: 85 | output = rec["digit"] 86 | output = tf.one_hot(output, 10) 87 | return rec["input"], output 88 | 89 | # Load train and validation splits of the dataset 90 | batch_size = 128 91 | train_data = dataset.as_tfdataset( 92 | "train", 93 | batch_size=batch_size, 94 | process_record=process_record, 95 | ) 96 | validation_data = dataset.as_tfdataset( 97 | "test", # validation split 98 | batch_size=batch_size, 99 | process_record=process_record, 100 | ) 101 | 102 | steps_per_epoch = 100 103 | epochs = 10 104 | _ = model.fit( 105 | train_data, 106 | steps_per_epoch=steps_per_epoch, 107 | epochs=epochs, 108 | validation_data=validation_data, 109 | validation_steps=steps_per_epoch // 10, 110 | ) 111 | 112 | # Evaluate the model on holdout. 113 | holdout_data = dataset.as_tfdataset( 114 | "holdout", 115 | batch_size=batch_size, 116 | process_record=process_record, 117 | repeat=False, # Single iteration over the dataset. 118 | ) 119 | score = model.evaluate( 120 | holdout_data, 121 | verbose=0, 122 | ) 123 | print(f"Test loss: {score[0]}") 124 | print(f"Test accuracy: {100 * score[1]:.2f}%") 125 | 126 | evaluated: int = 0 127 | ascii_shades = " .-/X0#" 128 | for example in dataset.as_numpy_iterator(split="holdout", 129 | process_record=process_record): 130 | # Stop after a few evaluations. 131 | evaluated += 1 132 | if evaluated >= args.ascii_evaluations: 133 | break 134 | 135 | # Pass just the input (the handwritten digit image) to the model and get 136 | # the predicted class as the class with highest probability. 137 | image: list[list[float]] = example[0] # type: ignore[assignment,index] 138 | # Note that the model still expects a batch, here we pass a batch of one 139 | # image. 140 | predicted_class: int = np.argmax(model(np.expand_dims( 141 | image, axis=0))) # type: ignore[assignment] 142 | correct_class: int = np.argmax( 143 | example[1]) # type: ignore[assignment,index] 144 | print("") 145 | print(f"Predicted: {predicted_class} (should be {correct_class}) for") 146 | # Turn into ASCII art 147 | for row in image: 148 | print("".join( 149 | ascii_shades[int(pixel * len(ascii_shades))] for pixel in row)) 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | --------------------------------------------------------------------------------