├── src
│   └── sedpack
│       ├── py.typed
│       ├── io
│       │   ├── flatbuffer
│       │   │   ├── shardfile
│       │   │   │   ├── __init__.py
│       │   │   │   ├── Example.py
│       │   │   │   ├── Shard.py
│       │   │   │   └── Attribute.py
│       │   │   ├── unit_tests
│       │   │   │   ├── shard_writer_flatbuffer_test_schema
│       │   │   │   │   └── __init__.py
│       │   │   │   └── shard_writer_flatbuffer_test_schema.fbs
│       │   │   ├── shard.fbs
│       │   │   └── __init__.py
│       │   ├── npz
│       │   │   └── __init__.py
│       │   ├── tfrec
│       │   │   ├── __init__.py
│       │   │   └── read.py
│       │   ├── shard
│       │   │   ├── __init__.py
│       │   │   ├── get_shard_writer.py
│       │   │   ├── iterate_shard_base.py
│       │   │   ├── shard.py
│       │   │   ├── shard_writer_tfrec.py
│       │   │   └── shard_writer_base.py
│       │   ├── iteration
│       │   │   └── __init__.py
│       │   ├── shard_info_iterator
│       │   │   ├── __init__.py
│       │   │   └── shard_info_iterator.py
│       │   ├── __init__.py
│       │   ├── itertools
│       │   │   └── __init__.py
│       │   ├── errors.py
│       │   ├── types.py
│       │   ├── compress.py
│       │   ├── dataset.py
│       │   └── merge_shard_infos.py
│       └── __init__.py
├── .style.yapf
├── website
│   ├── tsconfig.json
│   ├── src
│   │   ├── env.d.ts
│   │   ├── content
│   │   │   ├── docs
│   │   │   │   ├── tutorials
│   │   │   │   │   └── sca
│   │   │   │   │       ├── tiny_aes_trace.png
│   │   │   │   │       ├── tiny_aes_snr_sbi_0.png
│   │   │   │   │       ├── overview.mdx
│   │   │   │   │       └── dataset.mdx
│   │   │   │   ├── index.mdx
│   │   │   │   └── start_here
│   │   │   │       ├── intro.md
│   │   │   │       └── install.mdx
│   │   │   └── config.ts
│   │   └── mathjax.css
│   ├── .gitignore
│   ├── README.md
│   ├── package.json
│   └── astro.config.mjs
├── .github
│   ├── cspell_matcher.json
│   ├── workflows
│   │   ├── mdlint.yml
│   │   ├── yapf.yml
│   │   ├── piptest.yml
│   │   ├── rust_ci.yml
│   │   ├── pylint.yml
│   │   ├── spellcheck.yml
│   │   ├── deploy.yml
│   │   ├── base_benchmarks.yml
│   │   ├── mypy.yml
│   │   ├── fork_pr_benchmarks_track.yml
│   │   ├── fork_pr_benchmarks_run.yml
│   │   └── pytest.yml
│   ├── python_matcher.json
│   └── dependabot.yml
├── rust
│   ├── rustfmt.toml
│   ├── benches
│   │   ├── setup.sh
│   │   └── my_benchmark.rs
│   └── Cargo.toml
├── setup.py
├── cspell.json
├── tools
│   ├── run_pylint.sh
│   └── check_copyright.sh
├── .markdownlint.json
├── tests
│   └── io
│       ├── shard
│       │   ├── test_shard_writer_base.py
│       │   └── test_shard_write_async.py
│       ├── itertools
│       │   ├── test_itertools.py
│       │   └── test_lazy_pool.py
│       ├── test_error.py
│       ├── test_dataset_base.py
│       ├── test_file_info.py
│       ├── iteration
│       │   └── test_rust_generator.py
│       ├── npz
│       │   └── test_npz_shards.py
│       ├── test_hash_checksums.py
│       ├── test_bytes.py
│       ├── test_compression.py
│       ├── test_rust_iter.py
│       ├── test_end2end_async.py
│       ├── test_write_multiprocessing.py
│       ├── shard_info_iterator
│       │   └── test_balanced_iterator.py
│       ├── test_shard_custom_metadata.py
│       ├── test_continue_writing.py
│       ├── test_end2end_wrong_type.py
│       └── test_as_tfdataset.py
├── docs
│   ├── contributing.md
│   ├── code-of-conduct.md
│   └── tutorials
│       └── quick_start
│           ├── mnist_save.py
│           └── mnist_read_keras.py
├── project-words.txt
├── mypy.ini
├── pyproject.toml
├── README.md
├── .gitignore
└── base-tooling-requirements.txt
/src/sedpack/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/shardfile/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = google
3 |
--------------------------------------------------------------------------------
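
The style file above keys yapf off Google's preset. A minimal sketch (not part of the repository) of applying the same style programmatically through yapf's public `FormatCode` API; the snippet being formatted is illustrative:

```python
from yapf.yapflib.yapf_api import FormatCode

messy = "def f( a,b ) :\n  return a+b\n"

# style_config accepts a preset name or a path to a .style.yapf file.
formatted, changed = FormatCode(messy, style_config="google")
print(formatted)
# def f(a, b):
#     return a + b
```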
/website/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "astro/tsconfigs/strict"
3 | }
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/unit_tests/shard_writer_flatbuffer_test_schema/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/src/env.d.ts:
--------------------------------------------------------------------------------
1 | /// <reference path="../.astro/types.d.ts" />
2 | /// <reference types="astro/client" />
3 |
--------------------------------------------------------------------------------
/website/src/content/docs/tutorials/sca/tiny_aes_trace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sedpack/HEAD/website/src/content/docs/tutorials/sca/tiny_aes_trace.png
--------------------------------------------------------------------------------
/website/src/content/docs/tutorials/sca/tiny_aes_snr_sbi_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/sedpack/HEAD/website/src/content/docs/tutorials/sca/tiny_aes_snr_sbi_0.png
--------------------------------------------------------------------------------
/website/src/mathjax.css:
--------------------------------------------------------------------------------
1 | /* Starlight styles all images in Markdown with `display: block`, forcing
2 | * MathJax's inline SVG onto new lines. Fix by displaying it inline:
3 | */
4 | mjx-container svg {
5 | display: inline !important;
6 | }
7 |
--------------------------------------------------------------------------------
/website/src/content/config.ts:
--------------------------------------------------------------------------------
1 | import { defineCollection } from 'astro:content';
2 | import { docsSchema } from '@astrojs/starlight/schema';
3 |
4 | export const collections = {
5 | docs: defineCollection({ schema: docsSchema() }),
6 | };
7 |
--------------------------------------------------------------------------------
/website/.gitignore:
--------------------------------------------------------------------------------
1 | # build output
2 | dist/
3 | # generated types
4 | .astro/
5 |
6 | # dependencies
7 | node_modules/
8 |
9 | # logs
10 | npm-debug.log*
11 | yarn-debug.log*
12 | yarn-error.log*
13 | pnpm-debug.log*
14 |
15 |
16 | # environment variables
17 | .env
18 | .env.production
19 |
20 | # macOS-specific files
21 | .DS_Store
22 |
--------------------------------------------------------------------------------
/website/src/content/docs/index.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: Welcome to Sedpack
3 | description: Scalable and Efficient Data Packing
4 | template: splash
5 | hero:
6 | tagline: Start using Sedpack.
7 | actions:
8 | - text: Get started
9 | link: /sedpack/start_here/intro/
10 | icon: right-arrow
11 | ---
12 |
13 | import { Card, CardGrid } from '@astrojs/starlight/components';
14 |
--------------------------------------------------------------------------------
/website/src/content/docs/start_here/intro.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: Sedpack Intro
3 | description: Sedpack - Scalable and efficient data packing
4 | ---
5 |
6 | Scalable and efficient data packing
7 |
8 | See [code samples](https://github.com/google/sedpack/tree/main/docs).
9 |
10 | The code is a major refactor of the data saving and loading code from the
11 | [SCAAML](https://github.com/google/scaaml) project.
12 |
--------------------------------------------------------------------------------
/.github/cspell_matcher.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "spellcheck",
5 | "pattern": [
6 | {
7 | "regexp": "^([^:]+):(\\d+):([^:]+)\\s*-\\s*(.*)$",
8 | "file": 1,
9 | "line": 2,
10 | "column": 3,
11 | "message": 4
12 | }
13 | ]
14 | }
15 | ]
16 | }
17 |
--------------------------------------------------------------------------------
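
The `regexp` above is what GitHub Actions uses to lift cspell output into inline annotations. A quick sanity check of the pattern; the sample line is a made-up example of cspell's `file:line:column - message` output format:

```python
import re

# Same pattern as in cspell_matcher.json (single-escaped for Python).
pattern = re.compile(r"^([^:]+):(\d+):([^:]+)\s*-\s*(.*)$")
sample = "docs/intro.md:12:5 - Unknown word (speling)"

match = pattern.match(sample)
assert match is not None
file_, line, column, message = match.groups()
print(file_, line, column.strip(), message)
# docs/intro.md 12 5 Unknown word (speling)
```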
/rust/rustfmt.toml:
--------------------------------------------------------------------------------
1 | comment_width = 100
2 | fn_params_layout = "Compressed"
3 | format_code_in_doc_comments = true
4 | format_strings = true
5 | group_imports = "StdExternalCrate"
6 | imports_granularity = "Module"
7 | normalize_comments = true
8 | normalize_doc_attributes = true
9 | spaces_around_ranges = true
10 | use_small_heuristics = "Max"
11 | where_single_line = true
12 | wrap_comments = true
13 | ignore = [
14 | # Autogenerated files
15 | "src/shard_generated.rs",
16 | ]
17 |
--------------------------------------------------------------------------------
/website/README.md:
--------------------------------------------------------------------------------
1 | # Sedpack Documentation
2 |
3 | Website available at [google.github.io/sedpack](https://google.github.io/sedpack)
4 |
5 | ## Adding more documentation
6 |
7 | Local development preparation:
8 |
9 | - Install requirements: `sudo apt-get install nodejs npm`
10 | - Install project: `cd website/ ; npm install`
11 |
12 | Run local server: `npm run dev`
13 |
14 | ## Deploy to GitHub Pages
15 |
16 | Manually run the "Deploy to GitHub Pages" workflow (needs maintainer
17 | permission).
18 |
--------------------------------------------------------------------------------
/website/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "documentation",
3 | "type": "module",
4 | "version": "0.0.1",
5 | "scripts": {
6 | "dev": "astro dev",
7 | "start": "astro dev",
8 | "build": "astro check && astro build",
9 | "preview": "astro preview",
10 | "astro": "astro"
11 | },
12 | "dependencies": {
13 | "@astrojs/check": "^0.9.3",
14 | "@astrojs/starlight": "^0.37.0",
15 | "astro": "^5.16.0",
16 | "rehype-mathjax": "^7.1.0",
17 | "remark-math": "^6.0.0",
18 | "sharp": "^0.34.0",
19 | "typescript": "^5.9.2"
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/shard.fbs:
--------------------------------------------------------------------------------
1 | // Shard file schema using https://flatbuffers.dev/
2 |
3 | // Call `flatc --python sedpack/io/flatbuffer/shard.fbs` from the src/
4 | // directory, otherwise the autogenerated code contains wrong imports. Beware
5 | // that this overwrites all __init__.py files on the path.
6 | // Also remember to update `sedpack_rs/src/shard_generated.rs`.
7 |
8 |
9 | namespace sedpack.io.flatbuffer.shardfile;
10 |
11 | table Attribute {
12 | attribute_bytes:[ubyte];
13 | }
14 |
15 | table Example {
16 | attributes:[Attribute];
17 | }
18 |
19 | table Shard {
20 | examples:[Example];
21 | }
22 |
23 | root_type Shard;
24 |
--------------------------------------------------------------------------------
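
A sketch of decoding one shard through the flatc-generated Python classes in `sedpack/io/flatbuffer/shardfile/`. The accessor names follow standard flatc Python codegen for this schema; the file name is illustrative and an uncompressed shard is assumed (compressed shards would need decompressing first, per `DatasetStructure.compression`):

```python
from sedpack.io.flatbuffer.shardfile.Shard import Shard

with open("shard_0.fb", "rb") as shard_file:
    buffer = shard_file.read()

shard = Shard.GetRootAs(buffer, 0)
for i in range(shard.ExamplesLength()):
    example = shard.Examples(i)
    for j in range(example.AttributesLength()):
        # Raw bytes only; dtype and shape come from the dataset metadata,
        # not from the FlatBuffers schema.
        raw = example.Attributes(j).AttributeBytesAsNumpy()
        print(i, j, raw.shape)
```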
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Compatibility with legacy builds.
15 | """
16 |
17 | from setuptools import setup
18 |
19 | setup()
20 |
--------------------------------------------------------------------------------
/cspell.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://raw.githubusercontent.com/streetsidesoftware/cspell/main/cspell.schema.json",
3 | "version": "0.2",
4 | "dictionaryDefinitions": [
5 | {
6 | "name": "project-words",
7 | "path": "./project-words.txt",
8 | "addWords": true
9 | }
10 | ],
11 | "dictionaries": [
12 | "en-gb",
13 | "en_US",
14 | "project-words",
15 | "python"
16 | ],
17 | "ignorePaths": [
18 | ".pylintrc",
19 | "/.github",
20 | "/project-words.txt",
21 | "/tests",
22 | "base-tooling-requirements.txt",
23 | "cspell.json",
24 | "node_modules",
25 | "requirements.in",
26 | "requirements.txt"
27 | ]
28 | }
29 |
--------------------------------------------------------------------------------
/tools/run_pylint.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2022 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Ensure we are at the project root directory
17 | cd "$(readlink -f "$(dirname "$0")")/.." || exit 1
18 |
19 | pylint *.py src docs
20 |
--------------------------------------------------------------------------------
/src/sedpack/io/npz/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.npz.iterate_npz import IterateShardNP
17 |
18 | __all__ = [
19 | "IterateShardNP",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/sedpack/io/tfrec/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.tfrec.read import IterateShardTFRec
17 |
18 | __all__ = [
19 | "IterateShardTFRec",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.flatbuffer.iterate import IterateShardFlatBuffer
17 |
18 | __all__ = [
19 | "IterateShardFlatBuffer",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.shard.iterate_shard_base import IterateShardBase
17 | from sedpack.io.shard.shard import Shard
18 |
19 | __all__ = [
20 | "IterateShardBase",
21 | "Shard",
22 | ]
23 |
--------------------------------------------------------------------------------
/.markdownlint.json:
--------------------------------------------------------------------------------
1 | {
2 | "default": true,
3 | "heading-style": {
4 | "style": "atx"
5 | },
6 | "no-trailing-spaces": {
7 | "br_spaces": 0,
8 | "strict": true
9 | },
10 | "ul-indent": {
11 | "indent": 4
12 | },
13 | "line-length": {
14 | "line_length": 80,
15 | "heading_line_length": 120,
16 | "tables": false,
17 | "code_blocks": false
18 | },
19 | "list-marker-space": {
20 | "ol_single": 2,
21 | "ol_multi": 2,
22 | "ul_single": 3,
23 | "ul_multi": 3
24 | },
25 | "no-inline-html": {
26 | "allowed_elements": [
27 | "img"
28 | ]
29 | },
30 | "fenced-code-language": true,
31 | "code-block-style": {
32 | "style": "fenced"
33 | },
34 | "code-fence-style": {
35 | "style": "backtick"
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/.github/workflows/mdlint.yml:
--------------------------------------------------------------------------------
1 | name: markdownlint
2 | on:
3 | pull_request:
4 | types: [opened, synchronize, reopened]
5 | paths:
6 | - '**/*.md'
7 | - '**/*.mdx'
8 | - '.markdownlint.json'
9 | merge_group: # Needed for required workflows
10 | # Run after a review has been submitted (this is a required workflow which
11 | # might not be triggered when no code changes -- trigger before going to
12 | # merge queue).
13 | pull_request_review:
14 | types: [submitted]
15 |
16 | jobs:
17 | lint:
18 | permissions:
19 | contents: read
20 | pull-requests: write
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v6
24 | - uses: DavidAnson/markdownlint-cli2-action@07035fd053f7be764496c0f8d8f9f41f98305101 # v20
25 | with:
26 | config: .markdownlint.json
27 | globs: |
28 | *.md
29 | **/*.md
30 |
--------------------------------------------------------------------------------
/src/sedpack/io/iteration/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset iteration."""
15 |
16 | from sedpack.io.iteration.rust_batched_generator import RustBatchedGenerator
17 | from sedpack.io.iteration.rust_generator import RustGenerator
18 |
19 | __all__ = [
20 | "RustBatchedGenerator",
21 | "RustGenerator",
22 | ]
23 |
--------------------------------------------------------------------------------
/rust/benches/setup.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Setup datasets for Rust benchmarking. Call `bash benches/setup.sh` from the
16 | # `sedpack/rust` directory.
17 |
18 | python3 ../docs/tutorials/quick_start/mnist_save.py \
19 | --dataset_directory mnist_fb_gzip \
20 | --shard_file_type fb \
21 | --compression GZIP
22 |
--------------------------------------------------------------------------------
/tests/io/shard/test_shard_writer_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import get_args
16 |
17 | from sedpack.io.types import ShardFileTypeT
18 | from sedpack.io.shard.get_shard_writer import _SHARD_FILE_TYPE_TO_CLASS
19 |
20 |
21 | def test_all_file_types_supported():
22 | assert set(_SHARD_FILE_TYPE_TO_CLASS.keys()) == set(
23 | get_args(ShardFileTypeT))
24 |
--------------------------------------------------------------------------------
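
The test relies on `ShardFileTypeT` being (presumably) a `typing.Literal` whose members `typing.get_args` can recover at runtime, so the writer registry can be checked for completeness. A stand-in illustration:

```python
from typing import Literal, get_args

# Stand-in for sedpack.io.types.ShardFileTypeT (assumed to be a Literal).
ExampleFileTypeT = Literal["tfrec", "npz", "fb"]

assert set(get_args(ExampleFileTypeT)) == {"tfrec", "npz", "fb"}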
/src/sedpack/io/shard_info_iterator/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Iterating of shard information."""
15 |
16 | from sedpack.io.shard_info_iterator.shard_info_iterator import ShardInfoIterator
17 | from sedpack.io.shard_info_iterator.cached_shard_info_iterator import CachedShardInfoIterator
18 |
19 | __all__ = [
20 | "CachedShardInfoIterator",
21 | "ShardInfoIterator",
22 | ]
23 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/unit_tests/shard_writer_flatbuffer_test_schema.fbs:
--------------------------------------------------------------------------------
1 | // Shard file schema using https://flatbuffers.dev/
2 |
3 | // Call `flatc --python shard_writer_flatbuffer_test_schema.fbs` from the
4 | // tests/io/shard directory, otherwise the autogenerated code contains wrong
5 | // imports. Beware that this overwrites all __init__.py files on the path.
6 |
7 |
8 | namespace shard_writer_flatbuffer_test_schema;
9 |
10 | table NumPyVectorTest {
11 | // Vectors of types according to
12 | // https://flatbuffers.dev/flatbuffers_guide_writing_schema.html
13 |
14 | // 8 bit
15 | attribute_bool:[bool];
16 | attribute_byte:[byte];
17 | attribute_ubyte:[ubyte];
18 |
19 | // 16 bit
20 | attribute_short:[short];
21 | attribute_ushort:[ushort];
22 |
23 | // 32 bit
24 | attribute_int:[int];
25 | attribute_uint:[uint];
26 | attribute_float:[float];
27 |
28 | // 64 bit
29 | attribute_long:[long];
30 | attribute_ulong:[ulong];
31 | attribute_double:[double];
32 | }
33 |
34 | root_type NumPyVectorTest;
35 |
--------------------------------------------------------------------------------
/src/sedpack/io/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.dataset import Dataset
17 | from sedpack.io.dataset_filler import DatasetFiller, DatasetFillerContext
18 | from sedpack.io.metadata import Attribute
19 | from sedpack.io.metadata import DatasetStructure
20 | from sedpack.io.metadata import Metadata
21 |
22 | __all__ = [
23 | "Attribute",
24 | "Dataset",
25 | "DatasetFiller",
26 | "DatasetFillerContext",
27 | "DatasetStructure",
28 | "Metadata",
29 | ]
30 |
--------------------------------------------------------------------------------
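
A small sketch of the API exported above, modelled on the repository's own tests (e.g. `tests/io/test_dataset_base.py`); the path and attribute values are illustrative:

```python
from pathlib import Path

from sedpack.io import Attribute, Dataset, DatasetStructure, Metadata

dataset_structure = DatasetStructure(
    saved_data_description=[
        Attribute(name="attribute_name", dtype="float32", shape=(10,)),
    ],
    shard_file_type="fb",
    compression="LZ4",
)

dataset = Dataset.create(
    path=Path("/tmp/example_dataset"),  # Illustrative path.
    metadata=Metadata(description="Example dataset"),
    dataset_structure=dataset_structure,
)
```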
/src/sedpack/io/itertools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset creation and usage."""
15 |
16 | from sedpack.io.itertools.itertools import shuffle_buffer, shuffle_buffer_async
17 | from sedpack.io.itertools.itertools import round_robin, round_robin_async
18 | from sedpack.io.itertools.lazy_pool import LazyPool
19 | #from sedpack.io.itertools.lazy_pool_multiprocessing import LazyPool
20 |
21 | __all__ = [
22 | "shuffle_buffer",
23 | "shuffle_buffer_async",
24 | "round_robin",
25 | "round_robin_async",
26 | "LazyPool",
27 | ]
28 |
--------------------------------------------------------------------------------
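
Usage of the exported helpers, following the unit tests in `tests/io/itertools/test_itertools.py`: `round_robin` interleaves several iterables, and `shuffle_buffer` (judging by its name and tests) yields every element exactly once in pseudo-random order while buffering at most `buffer_size` elements:

```python
from sedpack.io.itertools import round_robin, shuffle_buffer

assert sorted(round_robin(["ABC", "D", "EF"])) == ["A", "B", "C", "D", "E", "F"]

for element in shuffle_buffer(range(10), buffer_size=3):
    print(element)  # All of 0..9, each exactly once, locally shuffled.
```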
/.github/workflows/yapf.yml:
--------------------------------------------------------------------------------
1 | name: yapf
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | paths:
9 | - '**/*.py'
10 | merge_group: # Needed for required workflows
11 | # Run after a review has been submitted (this is a required workflow which
12 | # might not be triggered when no code changes -- trigger before going to
13 | # merge queue).
14 | pull_request_review:
15 | types: [submitted]
16 |
17 | jobs:
18 | yapf:
19 | runs-on: ubuntu-22.04
20 | steps:
21 | - uses: actions/checkout@v6
22 | - name: Set up Python 3.10
23 | uses: actions/setup-python@v6
24 | with:
25 | python-version: '3.10'
26 | cache: 'pip'
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip setuptools wheel
30 | pip install --upgrade 'yapf>=0.30.0'
31 | - name: Register matcher
32 | run:
33 | echo ::add-matcher::./.github/python_matcher.json
34 | - name: Test code formatting with yapf
35 | run:
36 | yapf --recursive --diff .
37 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We would love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or
19 | to sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows [Google's Open Source Community
24 | Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
32 | for this purpose.
33 |
--------------------------------------------------------------------------------
/tools/check_copyright.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2023-2024 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | errors=0
18 | e() {
19 | echo -e "$(tput bold)$(tput setaf 1)Error:$(tput sgr0) $*"
20 | errors=$(( $errors + 1 ))
21 | }
22 |
23 | # Files we want to check for copyright
24 | EXTENSIONS="py\|sh"
25 |
26 | EXCLUDE_FILES="src/sedpack/io/flatbuffer/shardfile/*\.py"
27 |
28 | for file in $(git ls-files | \
29 | grep -e '\.\('"${EXTENSIONS}"'\)$' | \
30 | grep -v -e '^\('"${EXCLUDE_FILES}"'\)$')
31 | do
32 | sed -n 'N;/Copyright/q;q1' $file || e "No copyright notice in $file"
33 | done
34 |
35 | if [ $errors -gt 0 ]
36 | then
37 | exit 1
38 | fi
39 | exit 0
40 |
41 |
--------------------------------------------------------------------------------
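
The `sed -n 'N;/Copyright/q;q1'` idiom above succeeds exactly when "Copyright" appears within the first two lines of a file: `N` appends the second line to the pattern space, `q` exits 0 on a match, `q1` exits 1 otherwise. A rough Python equivalent for clarity (illustrative, not part of the repository):

```python
from pathlib import Path


def has_copyright_notice(path: Path) -> bool:
    """Mirror of sed -n 'N;/Copyright/q;q1': check the first two lines."""
    with path.open("r", encoding="utf-8", errors="replace") as file:
        first_two_lines = file.readline() + file.readline()
    return "Copyright" in first_two_lines
```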
/.github/workflows/piptest.yml:
--------------------------------------------------------------------------------
1 | name: piptest
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | paths:
9 | - 'docs/**/*.py'
10 | - 'pytest.ini'
11 | schedule:
12 | - cron: 0 5 * * 1 # Every Monday at 5:00 UTC
13 |
14 | jobs:
15 | piptesting:
16 | runs-on: ${{ matrix.platform.runner }}
17 | strategy:
18 | matrix:
19 | # ubuntu-24.04-arm is not stable enough
20 | platform:
21 | - runner: ubuntu-latest # x64
22 | - runner: windows-latest # x64
23 | - runner: macos-14 # arm64
24 | - runner: macos-15-intel # Intel
25 | - runner: macos-latest # arm64
26 | steps:
27 | - uses: actions/checkout@v6
28 | - name: Set up Python 3.10
29 | uses: actions/setup-python@v6
30 | with:
31 | python-version: '3.10'
32 | cache: 'pip'
33 | - name: Installing sedpack pip package
34 | run: |
35 | pip install sedpack
36 | - name: Run tutorial using sedpack pip package
37 | run: |
38 | python docs/tutorials/quick_start/mnist_save.py -d mnist_dataset
39 | python docs/tutorials/quick_start/mnist_read_keras.py -d mnist_dataset
40 |
--------------------------------------------------------------------------------
/.github/workflows/rust_ci.yml:
--------------------------------------------------------------------------------
1 | name: Rust CI
2 | on:
3 | pull_request:
4 | types: [opened, synchronize, reopened]
5 | paths:
6 | - '**/*.rs'
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | fmt:
13 | runs-on: ubuntu-latest
14 | defaults:
15 | run:
16 | working-directory: ./rust
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v6
20 | # Ensure rustfmt is installed and setup problem matcher
21 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1
22 | with:
23 | components: rustfmt
24 | toolchain: nightly
25 | matcher: true
26 | - name: fmt
27 | run: cargo +nightly fmt -- --check
28 | clippy:
29 | runs-on: ubuntu-latest
30 | defaults:
31 | run:
32 | working-directory: ./rust
33 | steps:
34 | - name: Checkout code
35 | uses: actions/checkout@v6
36 | # Ensure clippy is installed and setup problem matcher
37 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1
38 | with:
39 | components: clippy
40 | toolchain: nightly
41 | - name: clippy
42 | run: cargo +nightly clippy --all-targets
43 |
--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "sedpack_rs"
3 | version = "0.1.4"
4 | edition = "2024"
5 | description = "Rust bindings for sedpack a general ML dataset package"
6 | authors = [
7 | "Elie Bursztein",
8 | "Karel Král",
9 | "Jean-Michel Picod",
10 | ]
11 | license = "Apache-2.0"
12 | # So far just a small subset of sedpack functionality. Not meant to be
13 | # published as a crate.
14 | publish = false
15 |
16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
17 | [lib]
18 | name = "sedpack_rs"
19 | crate-type = ["cdylib", "rlib"]
20 |
21 | [dependencies]
22 | criterion = "0.8.1"
23 | flatbuffers = "25.9"
24 | flate2 = { version = "1.1.4" }
25 | glob = "0.3.3"
26 | lz4_flex = { version = "0.12.0", default-features = false , features = ["frame"] }
27 | numpy = "0.26"
28 | pyo3 = "0.26"
29 | rand = "0.9"
30 | rayon = "1.11.0"
31 | strum = "0.27.2"
32 | strum_macros = "0.27.2"
33 | tracing = "0.1.41"
34 | tracing-perfetto = "0.1.5"
35 | tracing-subscriber = "0.3.20"
36 | yoke = { version = "0.8", features = ["derive"] }
37 | zstd = "0.13.3"
38 |
39 | # Release performance optimizations.
40 | [profile.release]
41 | codegen-units = 1
42 | lto = true
43 | panic = "abort"
44 |
45 | [profile.bench]
46 | debug = true
47 |
48 | [[bench]]
49 | name = "my_benchmark"
50 | # Otherwise needs nightly.
51 | harness = false
52 |
--------------------------------------------------------------------------------
/project-words.txt:
--------------------------------------------------------------------------------
1 | Adafactor
2 | AdamW
3 | Babuska
4 | Bursztein
5 | CHES
6 | CUDA
7 | Corinna
8 | Elems
9 | GPAM
10 | Hastie
11 | IACR
12 | Invernizzi
13 | Josyula
14 | Karel
15 | Král
16 | Luca
17 | MNIST
18 | Mangard
19 | Moghimi
20 | Pankaj
21 | Picod
22 | Pytree
23 | Pytrees
24 | Rohatgi
25 | Ryzen
26 | SBOX
27 | SCAAML
28 | SCARR
29 | Sakkis
30 | Suresh
31 | Tibshirani
32 | Woudenberg
33 | Yann
34 | Yokeable
35 | arange
36 | argmax
37 | argnames
38 | asyncstdlib
39 | bldr
40 | booktitle
41 | burszteindc
42 | byteswap
43 | cdylib
44 | codegen
45 | compresslevel
46 | convnet
47 | crossentropy
48 | ddof
49 | diutils
50 | dtype
51 | dtypes
52 | dunder
53 | einsum
54 | fbapi
55 | flatbuffer
56 | flate
57 | frombuffer
58 | howpublished
59 | hyperparameter
60 | hyperparameters
61 | hypertuned
62 | hypertuning
63 | inproceedings
64 | ipynb
65 | itemsize
66 | kwarguments
67 | linalg
68 | logpdf
69 | mathbb
70 | mathjax
71 | ndarray
72 | newbyteorder
73 | parseable
74 | patchelf
75 | perfcounters
76 | perfetto
77 | pickleable
78 | pyarray
79 | pyclass
80 | pyfunction
81 | pymethods
82 | pymodule
83 | riscure
84 | rlib
85 | rngs
86 | savez
87 | sedpack
88 | setaf
89 | shardfile
90 | subclassing
91 | tensorspec
92 | tfdata
93 | tfdataset
94 | tfrec
95 | tfrecord
96 | tfrecords
97 | tinyaes
98 | tobytes
99 | uoffset
100 | vmap
101 | vmaps
102 | xored
103 | zstd
104 |
--------------------------------------------------------------------------------
/.github/python_matcher.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "yapf-diff",
5 | "pattern": [
6 | {
7 | "regexp": "^---\\s*([^\\s]*)\\s*\\(original\\)$",
8 | "file": 1
9 | },
10 | {
11 | "regexp": "^\\+\\+\\+\\s*([^\\s]*)\\s*\\((.*)\\)$",
12 | "message": 2
13 | },
14 | {
15 | "regexp": "^@@\\s*-(\\d+),(\\d+)\\s*\\+(\\d+),(\\d+)\\s*@@$",
16 | "line": 1
17 | }
18 | ]
19 | },
20 | {
21 | "owner": "pylint",
22 | "pattern": [
23 | {
24 | "regexp": "^([^:]+):(\\d+):(\\d+):\\s*([CEFIRW]\\d{4}):\\s*(.*)$",
25 | "file": 1,
26 | "line": 2,
27 | "column": 3,
28 | "code": 4,
29 | "message": 5
30 | }
31 | ]
32 | },
33 | {
34 | "owner": "mypy",
35 | "pattern": [
36 | {
37 | "regexp": "^([^:]+):(\\d+):\\s*([^:]+):\\s*(.*)$",
38 | "file": 1,
39 | "line": 2,
40 | "code": 3,
41 | "message": 4
42 | }
43 | ]
44 | }
45 | ]
46 | }
47 |
--------------------------------------------------------------------------------
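
As with the cspell matcher, these patterns can be sanity-checked against sample tool output; the line below is a made-up example of pylint's default `path:line:column: code: message` format:

```python
import re

# The "pylint" pattern from python_matcher.json (single-escaped for Python).
pattern = re.compile(r"^([^:]+):(\d+):(\d+):\s*([CEFIRW]\d{4}):\s*(.*)$")
sample = "src/sedpack/io/dataset.py:42:0: C0301: Line too long (101/80)"

match = pattern.match(sample)
assert match is not None
print(match.groups())
# ('src/sedpack/io/dataset.py', '42', '0', 'C0301', 'Line too long (101/80)')
```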
/src/sedpack/io/errors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Implement an error to indicate that a scaaml.io.Dataset already exists.
15 |
16 | Creating scaaml.io.Dataset should not overwrite existing files. When it could
17 | the constructor needs to raise an error, which should also contain the dataset
18 | directory.
19 | """
20 |
21 | from pathlib import Path
22 |
23 |
24 | class DatasetExistsError(FileExistsError):
25 | """Error for signalling that the dataset already exists."""
26 |
27 | def __init__(self, dataset_path: Path) -> None:
28 | """Represents that the dataset already exists.
29 |
30 | Args:
31 | dataset_path: The dataset path.
32 | """
33 | super().__init__(
34 | f'Dataset info file exists and would be overwritten. Use instead:'
35 | f' Dataset(path="{dataset_path}")')
36 | self.dataset_path = dataset_path
37 |
--------------------------------------------------------------------------------
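
How the error surfaces in practice, mirroring `tests/io/test_error.py`: `Dataset.create` raises `DatasetExistsError` instead of overwriting, and since it subclasses `FileExistsError`, existing handlers keep working. The path and structure below are illustrative:

```python
from pathlib import Path

from sedpack.io import Attribute, Dataset, DatasetStructure, Metadata
from sedpack.io.errors import DatasetExistsError

structure = DatasetStructure(saved_data_description=[
    Attribute(name="attribute_name", dtype="float32", shape=(10,)),
])

try:
    Dataset.create(
        path=Path("/tmp/existing_dataset"),
        metadata=Metadata(description="Example"),
        dataset_structure=structure,
    )
except DatasetExistsError as error:
    print(f"Refusing to overwrite {error.dataset_path}")
```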
/src/sedpack/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset library.
15 |
16 | Version format: MAJOR.MINOR.PATCH (see https://pypi.org/project/semver/ for
17 | more possibilities).
18 | """
19 |
20 | from importlib.metadata import version, PackageNotFoundError
21 |
22 | # The version of this package is defined by rust/Cargo.toml and dynamically
23 | # deduced by Maturin (see
24 | # https://www.maturin.rs/metadata.html#dynamic-metadata). To ensure
25 | # compatibility with existing code we dynamically set the __version__ attribute
26 | # here.
27 | try:
28 | # When package is installed use the version.
29 | __version__ = version("sedpack") # pylint: disable=invalid-name
30 | except PackageNotFoundError:
31 | # Package is not installed. The Rust part of this package is probably not
32 | # going to work in this case (the Rust binding would probably be missing).
33 | __version__ = "0.0.7-dev" # pylint: disable=invalid-name
34 |
--------------------------------------------------------------------------------
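
Since `__version__` follows MAJOR.MINOR.PATCH, it can be inspected with the `semver` package already listed in `pyproject.toml`. An illustrative sketch (API shown is semver 3.x; the dev fallback "0.0.7-dev" parses as a prerelease):

```python
import semver

import sedpack

parsed = semver.Version.parse(sedpack.__version__)
print(parsed.major, parsed.minor, parsed.patch, parsed.prerelease)
```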
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: pylint
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | paths:
9 | - '**/*.py'
10 | - '.pylintrc'
11 | merge_group: # Needed for required workflows
12 | # Run after a review has been submitted (this is a required workflow which
13 | # might not be triggered when no code changes -- trigger before going to
14 | # merge queue).
15 | pull_request_review:
16 | types: [submitted]
17 |
18 | jobs:
19 | copyright_header:
20 | runs-on: ubuntu-22.04
21 | steps:
22 | - uses: actions/checkout@v6
23 | - name: Check licence headers
24 | run: ./tools/check_copyright.sh
25 |
26 | pylint:
27 | runs-on: ubuntu-22.04
28 | strategy:
29 | matrix:
30 | python-version: ['3.10', '3.11', '3.12']
31 | steps:
32 | - uses: actions/checkout@v6
33 | - name: Set up Python ${{ matrix.python-version }}
34 | uses: actions/setup-python@v6
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 | cache: 'pip'
38 | - name: Install dependencies
39 | run: |
40 | python -m pip install --upgrade pip setuptools wheel
41 | pip install --require-hashes --no-deps -r requirements.txt
42 | pip install --upgrade pylint
43 | - name: Register matcher
44 | run: echo ::add-matcher::./.github/python_matcher.json
45 | - name: Test code with pylint
46 | run: ./tools/run_pylint.sh
47 |
48 |
--------------------------------------------------------------------------------
/tests/io/itertools/test_itertools.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import random
16 |
17 | from sedpack.io.itertools import *
18 |
19 |
20 | def test_round_robin_simple() -> None:
21 | l = ["ABC", "D", "EFGH"]
22 |
23 | assert sorted(round_robin(l)) == list("ABCDEFGH")
24 |
25 |
26 | def test_round_robin_long_iter() -> None:
27 | l = map(lambda x: range(x, x + 10), range(100))
28 |
29 | assert len(list(round_robin(l))) == 1_000
30 |
31 |
32 | def test_round_robin_docstring() -> None:
33 | l = ["ABC", "D", "EF"]
34 |
35 | assert sorted(round_robin(l)) == ["A", "B", "C", "D", "E", "F"]
36 |
37 |
38 | def test_random_shuffle() -> None:
39 | random.seed(42)
40 | l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
41 |
42 | seen = []
43 |
44 | for x in shuffle_buffer(l, buffer_size=3):
45 | seen.append(x)
46 |
47 | # We have seen all elements and each only once.
48 | assert sorted(l) == sorted(seen)
49 | assert l != seen
50 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | plugins = pydantic.mypy
3 | show_error_codes = True
4 | follow_imports = silent
5 | local_partial_types = true
6 | strict_equality = true
7 | no_implicit_optional = true
8 | warn_incomplete_stub = true
9 | warn_redundant_casts = true
10 | warn_unused_configs = true
11 | warn_unused_ignores = true
12 | enable_error_code = ignore-without-code, redundant-self, truthy-iterable
13 | disable_error_code = annotation-unchecked, import-not-found, import-untyped
14 | extra_checks = false
15 | check_untyped_defs = true
16 | disallow_incomplete_defs = true
17 | disallow_subclassing_any = true
18 | disallow_untyped_calls = true
19 | disallow_untyped_decorators = true
20 | disallow_untyped_defs = true
21 | warn_return_any = true
22 | warn_unreachable = true
23 | allow_redefinition = false
24 | strict_optional = true
25 |
26 | [pydantic-mypy]
27 | init_forbid_extra = true
28 | init_typed = true
29 | warn_required_dynamic_aliases = true
30 | warn_untyped_fields = true
31 |
32 | [mypy-sedpack.*]
33 | no_implicit_reexport = true
34 | disallow_untyped_calls = true
35 | disallow_any_unimported = true
36 | disallow_untyped_decorators = true
37 | strict = true
38 | enable_error_code = ignore-without-code, redundant-self, truthy-iterable, possibly-undefined, truthy-bool, truthy-iterable, unused-ignore, mutable-override
39 |
40 | [mypy-tests.*]
41 | disallow_untyped_defs = false
42 | disallow_untyped_calls = false
43 | disallow_untyped_decorators = false
44 |
45 | [mypy-sedpack.io.flatbuffer.unit_tests.*]
46 | disallow_untyped_defs = false
47 | disallow_untyped_calls = false
48 | disallow_untyped_decorators = false
49 |
--------------------------------------------------------------------------------
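
What the per-module overrides mean in practice: under `[mypy-sedpack.*]` the strict flags reject unannotated definitions, while `[mypy-tests.*]` relaxes them. An illustrative snippet (not from the repository):

```python
def untyped(x):  # Rejected where disallow_untyped_defs = true.
    return x + 1


def typed(x: int) -> int:  # Accepted under either configuration.
    return x + 1
```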
/.github/workflows/spellcheck.yml:
--------------------------------------------------------------------------------
1 | name: spellcheck
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | merge_group: # Needed for required workflows
9 |
10 | jobs:
11 | spellchecking:
12 | runs-on: ubuntu-22.04
13 | steps:
14 | - name: Checkout the code
15 | uses: actions/checkout@v6
16 | with:
17 | # We need all history to list all changed files.
18 | fetch-depth: 0
19 | - name: Set up node
20 | uses: actions/setup-node@v6
21 | with:
22 | node-version: "21"
23 | - name: Install cspell
24 | run: npm install --location=global cspell
25 | - name: Run spell checking
26 | run: |
27 | # Add a problem matcher.
28 | echo ::add-matcher::./.github/cspell_matcher.json
29 | # Set internal field separator (filenames can have space in them).
30 | IFS=$'\n'
31 | # Find out which files were changed since main.
32 | git diff --name-only origin/main..${{ github.event.after }} | { # Open sub-shell
33 | # Spell check all files and remember if some failed.
34 | EXIT_CODE=0
35 | # Loop over the changed files.
36 | while read CHANGED_FILE; do
37 | # Run cspell on CHANGED_FILE, do not fail if the file is
38 | # ignored or the spell-checking fails.
39 | cspell --config ./cspell.json --no-must-find-files "$CHANGED_FILE" || EXIT_CODE=$?
40 | echo $EXIT_CODE
41 | done ;
42 | exit $EXIT_CODE ; # Fail if some check failed.
43 | } ;
44 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard/get_shard_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for shard writing depending on shard_file_type.
15 | """
16 |
17 | from pathlib import Path
18 |
19 | from sedpack.io.metadata import DatasetStructure
20 | from sedpack.io.types import ShardFileTypeT
21 |
22 | from sedpack.io.shard.shard_writer_base import ShardWriterBase
23 | from sedpack.io.shard.shard_writer_flatbuffer import ShardWriterFlatBuffer
24 | from sedpack.io.shard.shard_writer_np import ShardWriterNP
25 | from sedpack.io.shard.shard_writer_tfrec import ShardWriterTFRec
26 |
27 | _SHARD_FILE_TYPE_TO_CLASS: dict[ShardFileTypeT, type[ShardWriterBase]] = {
28 | "tfrec": ShardWriterTFRec,
29 | "npz": ShardWriterNP,
30 | "fb": ShardWriterFlatBuffer,
31 | }
32 |
33 |
34 | def get_shard_writer(dataset_structure: DatasetStructure,
35 | shard_file: Path) -> ShardWriterBase:
36 | """Return the right subclass of ShardWriterBase.
37 | """
38 | return _SHARD_FILE_TYPE_TO_CLASS[dataset_structure.shard_file_type](
39 | dataset_structure=dataset_structure,
40 | shard_file=shard_file,
41 | )
42 |
--------------------------------------------------------------------------------
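
A sketch of the dispatch above: the writer class is chosen by `DatasetStructure.shard_file_type` (attribute values as in the tests; the shard path is illustrative):

```python
from pathlib import Path

from sedpack.io.metadata import Attribute, DatasetStructure
from sedpack.io.shard.get_shard_writer import get_shard_writer
from sedpack.io.shard.shard_writer_flatbuffer import ShardWriterFlatBuffer

dataset_structure = DatasetStructure(
    saved_data_description=[
        Attribute(name="attribute_name", dtype="float32", shape=(10,)),
    ],
    shard_file_type="fb",
)

writer = get_shard_writer(
    dataset_structure=dataset_structure,
    shard_file=Path("/tmp/shard_0.fb"),
)
assert isinstance(writer, ShardWriterFlatBuffer)
```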
/tests/io/test_error.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 |
18 | import sedpack
19 | from sedpack.io import Dataset
20 | from sedpack.io import Metadata
21 | from sedpack.io.errors import DatasetExistsError
22 |
23 |
24 | def test_dataset_exists(tmpdir: str | Path) -> None:
25 | tiny_experiment_path: Path = Path(tmpdir) / "exists"
26 | # Metadata
27 | dataset_metadata = Metadata(description="Test of the lib")
28 | example_attributes = [
29 | sedpack.io.metadata.Attribute(
30 | name="attribute_name",
31 | dtype="float32",
32 | shape=(10,),
33 | ),
34 | ]
35 | dataset_structure = sedpack.io.metadata.DatasetStructure(
36 | saved_data_description=example_attributes)
37 |
38 | # First should be ok.
39 | dataset = Dataset.create(
40 | path=tiny_experiment_path,
41 | metadata=dataset_metadata,
42 | dataset_structure=dataset_structure,
43 | )
44 |
45 | with pytest.raises(DatasetExistsError):
46 | # This would overwrite.
47 | dataset = Dataset.create(
48 | path=tiny_experiment_path,
49 | metadata=dataset_metadata,
50 | dataset_structure=dataset_structure,
51 | )
52 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.7,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [tool.maturin]
6 | python-source = "src"
7 | manifest-path = "rust/Cargo.toml"
8 | features = ["pyo3/extension-module"]
9 | # Implemented in Rust:
10 | module-name = "sedpack._sedpack_rs"
11 |
12 | [project]
13 | name = "sedpack"
14 | authors = [
15 | { name="Elie Bursztein"},
16 | { name="Karel Král"},
17 | { name="Jean-Michel Picod"},
18 | ]
19 | description = "General ML dataset package"
20 | readme = "README.md"
21 | requires-python = ">=3.10"
22 | keywords = ["machine learning", "dataset"]
23 | license = {text = "Apache License 2.0"}
24 | classifiers = [
25 | "Development Status :: 5 - Production/Stable",
26 | "Environment :: Console",
27 | "Framework :: Jupyter",
28 | "License :: OSI Approved :: Apache Software License",
29 | "Intended Audience :: Science/Research",
30 | "Programming Language :: Python :: 3",
31 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | ]
33 | dynamic = ["version"]
34 | dependencies = [
35 | "aiofiles",
36 | "asyncstdlib",
37 | "flatbuffers",
38 | "lz4",
39 | "numpy",
40 | "pydantic",
41 | "semver",
42 | "tenacity",
43 | "tensorflow",
44 | "tqdm",
45 | "xxhash",
46 | "zstandard",
47 | ]
48 |
49 | [project.optional-dependencies]
50 | dev = [
51 | "maturin[patchelf,zig] ; platform_system != 'Windows'",
52 | "maturin[zig] ; platform_system == 'Windows'",
53 | "mypy",
54 | "pylint",
55 | "pytest",
56 | "pytest-asyncio",
57 | "pytest-cov",
58 | "types-aiofiles",
59 | "types-tensorflow",
60 | "types-tqdm",
61 | "yapf",
62 | ]
63 |
64 | [project.scripts]
65 |
66 | [project.urls]
67 | "Homepage" = "https://github.com/google/sedpack"
68 | "Bug Tracker" = "https://github.com/google/sedpack"
69 |
70 | [tool.ruff]
71 | target-version = "py310"
72 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
1 | name: Deploy to GitHub Pages
2 |
3 | on:
4 | # Trigger the workflow every time you push to the `main` branch
5 | # Using a different branch name? Replace `main` with your branch’s name
6 | push:
7 | branches: [ main ]
8 | paths: [ 'website/**' ]
9 | merge_group: # Needed for required workflows
10 | # Allows you to run this workflow manually from the Actions tab on GitHub.
11 | workflow_dispatch:
12 |
13 | # Allow this job to clone the repo and create a page deployment
14 | permissions:
15 | contents: read
16 | pages: write
17 | id-token: write
18 |
19 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
20 | concurrency:
21 | group: "pages"
22 | cancel-in-progress: false
23 |
24 | jobs:
25 | build:
26 | # Run only on the upstream repository.
27 | if: github.repository == 'google/sedpack'
28 | runs-on: ubuntu-latest
29 | steps:
30 | - name: Checkout your repository using git
31 | uses: actions/checkout@v6
32 | - name: Install, build, and upload your site
33 | uses: withastro/action@2226b5671ff302b175d9843add614af27e60bbfc # v4
34 | with:
35 | path: website # The root location of your Astro project inside the repository. (optional)
36 | # node-version: 20 # The specific version of Node that should be used to build your site. Defaults to 20. (optional)
37 | # package-manager: pnpm@latest # The Node package manager that should be used to install dependencies and build your site. Automatically detected based on your lockfile. (optional)
38 |
39 | deploy:
40 | needs: build
41 | runs-on: ubuntu-latest
42 | environment:
43 | name: github-pages
44 | url: ${{ steps.deployment.outputs.page_url }}
45 | steps:
46 | - name: Deploy to GitHub Pages
47 | id: deployment
48 | uses: actions/deploy-pages@v4
49 |
--------------------------------------------------------------------------------
/.github/workflows/base_benchmarks.yml:
--------------------------------------------------------------------------------
1 | # Based on the tutorial https://bencher.dev/docs/how-to/github-actions/
2 |
3 | on:
4 | push:
5 | branches: main
6 | paths:
7 | - '**/*.rs'
8 | - 'rust/Cargo.toml'
9 | - 'rust/Cargo.lock'
10 | schedule:
11 | # Run once a month (arbitrary minute, hour, and day of month; any month and day of week)
12 | - cron: "3 1 4 * *"
13 | # Allows you to run this workflow manually from the Actions tab on GitHub.
14 | workflow_dispatch:
15 |
16 | jobs:
17 | benchmark_base_branch:
18 | # Run only on the upstream repository.
19 | if: github.repository == 'google/sedpack'
20 | name: Continuous Benchmarking with Bencher
21 | permissions:
22 | checks: write
23 | runs-on: ubuntu-latest
24 | steps:
25 | - uses: actions/checkout@v6
26 | - name: Set up Python
27 | uses: actions/setup-python@v6
28 | with:
29 | cache: 'pip'
30 | - name: Installing sedpack pip package
31 | run: |
32 | pip install sedpack
33 | - name: Prepare benchmarking data
34 | working-directory: ./rust
35 | run: bash benches/setup.sh
36 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1
37 | - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5
38 | - name: Track base branch benchmarks with Bencher
39 | working-directory: ./rust
40 | run: |
41 | bencher run \
42 | --project sedpack \
43 | --token '${{ secrets.BENCHER_API_TOKEN }}' \
44 | --branch main \
45 | --testbed ubuntu-latest \
46 | --threshold-measure latency \
47 | --threshold-test t_test \
48 | --threshold-max-sample-size 64 \
49 | --threshold-upper-boundary 0.99 \
50 | --thresholds-reset \
51 | --err \
52 | --adapter rust_criterion \
53 | --github-actions '${{ secrets.GITHUB_TOKEN }}' \
54 | "cargo bench"
55 |
--------------------------------------------------------------------------------
/tests/io/test_dataset_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from typing import Union
17 |
18 | import numpy as np
19 |
20 | import sedpack
21 | from sedpack.io import Dataset, Metadata
22 |
23 |
24 | def test_dataset_info_is_copy(tmpdir: Union[str, Path]) -> None:
25 | dtype = "float32"
26 | compression = "LZ4"
27 | tiny_experiment_path: Path = Path(tmpdir) / "dataset_info"
28 | array_of_values = np.random.random((1024, 138))
29 | array_of_values = array_of_values.astype(dtype)
30 |
31 | # Create a dataset
32 |
33 | dataset_metadata = Metadata(description="Test of the lib")
34 |
35 | example_attributes = [
36 | sedpack.io.metadata.Attribute(
37 | name="attribute_name",
38 | dtype=str(dtype),
39 | shape=array_of_values[0].shape,
40 | ),
41 | ]
42 |
43 | dataset_structure = sedpack.io.metadata.DatasetStructure(
44 | saved_data_description=example_attributes,
45 | compression=compression,
46 | examples_per_shard=256,
47 | shard_file_type="fb",
48 | )
49 |
50 | dataset = Dataset.create(
51 | path=tiny_experiment_path,
52 | metadata=dataset_metadata,
53 | dataset_structure=dataset_structure,
54 | )
55 |
56 | old_dataset_info = dataset.dataset_info
57 |
58 | # Will make a copy
59 | dataset.dataset_info.dataset_structure.compression = ""
60 |
61 | assert old_dataset_info == dataset.dataset_info
62 |
--------------------------------------------------------------------------------
/website/astro.config.mjs:
--------------------------------------------------------------------------------
1 | // @ts-check
2 | import { defineConfig } from 'astro/config';
3 | import starlight from '@astrojs/starlight';
4 | import remarkMath from 'remark-math';
5 | import rehypeMathJax from 'rehype-mathjax';
6 |
7 | // https://astro.build/config
8 | export default defineConfig({
9 | site: 'https://google.github.io/sedpack/',
10 | base: '/sedpack',
11 |
12 | // Configure `remark-math` and `rehype-mathjax` plugins:
13 | markdown: {
14 | remarkPlugins: [remarkMath],
15 | rehypePlugins: [rehypeMathJax],
16 | },
17 |
18 | integrations: [
19 | starlight({
20 | title: 'Sedpack Documentation',
21 | social: [
22 | {
23 | icon: 'github',
24 | label: 'GitHub',
25 | href: 'https://github.com/google/sedpack',
26 | }
27 | ],
28 | // Custom CSS to style MathJax equations
29 | customCss: ['./src/mathjax.css'],
30 | sidebar: [
31 | {
32 | label: 'Start Here',
33 | items: [
34 | // Each item here is one entry in the navigation menu.
35 | { label: 'Getting Started', slug: 'start_here/intro' },
36 | { label: 'Installation', slug: 'start_here/install' },
37 | ],
38 | },
39 | {
40 | label: 'Tutorials',
41 | items: [
42 | { label: 'MNIST', slug: 'tutorials/mnist' },
43 | {
44 | label: 'Side Channel Attacks',
45 | items: [
46 | { label: 'SCA Overview', slug: 'tutorials/sca/overview' },
47 | { label: 'Dataset Preparation', slug: 'tutorials/sca/dataset' },
48 | {
49 | label: 'Classical Attacks',
50 | items: [
51 | { label: 'Signal to Noise Ratio', slug: 'tutorials/sca/snr' },
52 | { label: 'GPU Acceleration of CPA and Template Attacks', slug: 'tutorials/sca/gpu_cpa_template' },
53 | ],
54 | },
55 | { label: 'Deep Learning (GPAM)', slug: 'tutorials/sca/gpam' },
56 | ],
57 | },
58 | ],
59 | },
60 | ],
61 | }),
62 | ],
63 | });
64 |
--------------------------------------------------------------------------------
/tests/io/test_file_info.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 |
18 | from sedpack.io.file_info import PathGenerator
19 |
20 |
21 | def test_multilevel() -> None:
22 | levels: int = 3
23 | max_branching: int = 2
24 | name_length: int = 3
25 |
26 | generator = PathGenerator(
27 | levels=levels,
28 | max_branching=max_branching,
29 | name_length=name_length,
30 | )
31 |
32 | seen_paths: set[Path] = set()
33 |
34 | # No exception in the top level
35 | for _ in range((max_branching**levels) + 10):
36 | p = generator.get_path()
37 | assert len(p.parts) == levels
38 |
39 |         # Last part is long, previous parts are exactly name_length long
40 | assert len(p.parts[-1]) >= 32
41 | assert all(len(part) == name_length for part in p.parts[:-1])
42 |
43 | seen_paths.add(p)
44 |
45 | # Enforce bounded number of subdirectories except the top
46 | for l in range(levels - 1):
47 | for p in seen_paths:
48 | prefix = p.parents[l]
49 | count = sum(path.is_relative_to(prefix) for path in seen_paths)
50 | assert count <= max_branching**(l + 1)
51 |
52 |
53 | def test_single_level() -> None:
54 | max_branching: int = 3
55 | name_length: int = 3
56 | generator = PathGenerator(
57 | levels=1,
58 | max_branching=max_branching,
59 | name_length=name_length,
60 | )
61 |
62 | # No exception in the top level
63 | for _ in range(max_branching + 10):
64 | n = generator.get_path()
65 | assert len(str(n)) >= 32 # last level is unbounded length
66 | assert len(n.parts) == 1
67 |
--------------------------------------------------------------------------------
/.github/workflows/mypy.yml:
--------------------------------------------------------------------------------
1 | name: mypy
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | paths:
9 | - '**/*.py'
10 | - '**/*.rs'
11 | - 'mypy.ini'
12 | merge_group: # Needed for required workflows
13 | # Run after a review has been submitted (this is a required workflow which
14 | # might not be triggered when no code changes -- trigger before going to
15 | # merge queue).
16 | pull_request_review:
17 | types: [submitted]
18 |
19 | jobs:
20 | mypy:
21 | runs-on: ubuntu-22.04
22 | steps:
23 | - uses: actions/checkout@v6
24 | - name: Set up Python 3.10
25 | uses: actions/setup-python@v6
26 | with:
27 | python-version: '3.10'
28 | cache: 'pip'
29 | - name: Get pip cache directory
30 | id: pip-cache
31 | shell: bash
32 | run: |
33 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
34 | - name: Use cached venv or create it
35 | uses: actions/cache/restore@v5
36 | id: cache
37 | with:
38 | path: ${{ steps.pip-cache.outputs.dir }}
39 | # The cache key depends on requirements.txt
40 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('pyproject.toml') }}
41 | restore-keys: |
42 | ${{ runner.os }}-pip-
43 | # Build a virtualenv, but only if it doesn't already exist
44 | - name: Populate pip cache
45 | run: |
46 | pip install --require-hashes --no-deps -r requirements.txt
47 | pip install --editable ".[dev]"
48 | - name: Save cache
49 | id: cache-save
50 | uses: actions/cache/save@v5
51 | with:
52 | path: ${{ steps.pip-cache.outputs.dir }}
53 | key: ${{ steps.cache.outputs.cache-primary-key }}
54 | if: steps.cache.outputs.cache-hit != 'true'
55 | - name: Register matcher
56 | run: echo ::add-matcher::./.github/python_matcher.json
57 | - name: Running mypy
58 | run: |
59 | echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV
60 | mkdir -p .mypy_cache
61 | mypy --version
62 | mypy --no-color-output --install-types --non-interactive src docs
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sedpack - Scalable and efficient data packing
2 |
3 | [](https://coveralls.io/github/google/sedpack?branch=main)
4 |
5 | [Documentation](https://google.github.io/sedpack/)
6 |
7 | Mainly refactored from the [SCAAML](https://github.com/google/scaaml) project.
8 |
9 | ## Available components
10 |
11 | See the documentation website:
12 | [https://google.github.io/sedpack/](https://google.github.io/sedpack/).
13 |
14 | ## Install
15 |
16 | ### Dependencies
17 |
18 | Some components (e.g., TFRecord shards) need a working version of [TensorFlow
19 | 2.x](https://www.tensorflow.org/install); otherwise TensorFlow is optional.
20 |
21 | Development dependencies:
22 |
23 | - python-dev and gcc for [xxhash](https://pypi.org/project/xxhash/)
24 |
25 | ### Dataset install
26 |
27 | #### Development install
28 |
29 | 1. Clone the repository: `git clone https://github.com/google/sedpack`
30 | 2. Install dependencies: `python3 -m pip install --require-hashes -r requirements.txt`
31 | 3. Install the package in development mode: `python3 -m pip install --editable
32 |    .` (shorthand: `pip install -e .`; legacy: `python setup.py develop`)
33 |
34 | #### Rust install
35 |
36 | - Activate your Python virtual environment
37 | - [Install Rust](https://www.rust-lang.org/tools/install)
38 | - Run `maturin develop --release`
39 | - Run `python -m pytest` from the project root directory -- no tests should
40 | be skipped
41 |
42 | ### Update dependencies
43 |
44 | Make sure you have run `sudo apt install python3 python3-pip python3-venv` and
45 | activated the virtual environment.
46 |
47 | Install requirements: `pip install --require-hashes -r base-tooling-requirements.txt`
48 |
49 | Update: `pip-compile pyproject.toml --generate-hashes --upgrade` and commit requirements.txt.
50 |
51 | #### Package install
52 |
53 | `pip install sedpack`
54 |
55 | ### Tutorial
56 |
57 | A tutorial and documentation are available at
58 | [https://google.github.io/sedpack/](https://google.github.io/sedpack/).
59 |
60 | Code for the tutorials is available in the `docs/tutorials` directory. For a
61 | "hello world" see
62 | [https://google.github.io/sedpack/tutorials/mnist/](https://google.github.io/sedpack/tutorials/mnist/).
63 |
64 | ## Disclaimer
65 |
66 | This is not an official Google product.
67 |
--------------------------------------------------------------------------------
/tests/io/iteration/test_rust_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 | from typing import Union
18 |
19 | import numpy as np
20 |
21 | import sedpack
22 | from sedpack.io.iteration import RustGenerator
23 | from sedpack.io.metadata import DatasetStructure
24 |
25 |
26 | def test_wrong_file_parallelism() -> None:
27 | with pytest.raises(
28 | ValueError,
29 | match="The argument file_parallelism should be positive.*",
30 | ):
31 | g = RustGenerator(
32 | dataset_path=Path(),
33 | dataset_structure=DatasetStructure(),
34 | shard_iterator=[],
35 | process_record=None,
36 | file_parallelism=0,
37 | )
38 |
39 |
40 | def test_wrong_shard_type() -> None:
41 | with pytest.raises(
42 | ValueError,
43 | match="RustGenerator is implemented only for FlatBuffers.",
44 | ):
45 | g = RustGenerator(
46 | dataset_path=Path(),
47 | dataset_structure=DatasetStructure(shard_file_type="tfrec"),
48 | shard_iterator=[],
49 | process_record=None,
50 | file_parallelism=1,
51 | )
52 |
53 |
54 | def test_wrong_compression() -> None:
55 | with pytest.raises(
56 | ValueError,
57 | match=
58 | "The compression .* is not among the supported compressions: .*",
59 | ):
60 | g = RustGenerator(
61 | dataset_path=Path(),
62 | dataset_structure=DatasetStructure(
63 | shard_file_type="fb",
64 | compression="ZIP",
65 | ),
66 | shard_iterator=[],
67 | process_record=None,
68 | file_parallelism=1,
69 | )
70 |
--------------------------------------------------------------------------------
/.github/workflows/fork_pr_benchmarks_track.yml:
--------------------------------------------------------------------------------
1 | # Based on https://bencher.dev/docs/how-to/github-actions/
2 | name: Track Benchmarks with Bencher
3 | permissions:
4 | contents: read
5 | pull-requests: write
6 |
7 | on:
8 | workflow_run:
9 | workflows: [Run Benchmarks]
10 | types: [completed]
11 |
12 | jobs:
13 | track_fork_pr_branch:
14 | if: github.event.workflow_run.conclusion == 'success'
15 | runs-on: ubuntu-latest
16 | env:
17 | BENCHMARK_RESULTS: benchmark_results.txt
18 | PR_EVENT: event.json
19 | steps:
20 | - name: Download Benchmark Results
21 | uses: dawidd6/action-download-artifact@ac66b43f0e6a346234dd65d4d0c8fbb31cb316e5 # v11
22 | with:
23 | name: ${{ env.BENCHMARK_RESULTS }}
24 | run_id: ${{ github.event.workflow_run.id }}
25 | - name: Download PR Event
26 | uses: dawidd6/action-download-artifact@ac66b43f0e6a346234dd65d4d0c8fbb31cb316e5 # v11
27 | with:
28 | name: ${{ env.PR_EVENT }}
29 | run_id: ${{ github.event.workflow_run.id }}
30 | - name: Figure out what is where
31 | run: ls
32 | - name: Export PR Event Data
33 | uses: actions/github-script@v8
34 | with:
35 | script: |
36 | let fs = require('fs');
37 | let prEvent = JSON.parse(fs.readFileSync(process.env.PR_EVENT, {encoding: 'utf8'}));
38 | core.exportVariable("PR_HEAD", prEvent.pull_request.head.ref);
39 | core.exportVariable("PR_BASE", prEvent.pull_request.base.ref);
40 | core.exportVariable("PR_BASE_SHA", prEvent.pull_request.base.sha);
41 | core.exportVariable("PR_NUMBER", prEvent.number);
42 | - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5
43 | - name: Track Benchmarks with Bencher
44 | run: |
45 | bencher run \
46 | --project sedpack \
47 | --token '${{ secrets.BENCHER_API_TOKEN }}' \
48 | --branch "$PR_HEAD" \
49 | --start-point "$PR_BASE" \
50 | --start-point-hash "$PR_BASE_SHA" \
51 | --start-point-clone-thresholds \
52 | --start-point-reset \
53 | --testbed ubuntu-latest \
54 | --err \
55 | --adapter rust_criterion \
56 | --github-actions '${{ secrets.GITHUB_TOKEN }}' \
57 | --ci-number "$PR_NUMBER" \
58 | --file "$BENCHMARK_RESULTS"
59 |
--------------------------------------------------------------------------------
/tests/io/npz/test_npz_shards.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 |
17 | import numpy as np
18 |
19 | from sedpack.io.metadata import Attribute, DatasetStructure
20 | from sedpack.io.shard.shard_writer_np import ShardWriterNP
21 | from sedpack.io.npz import IterateShardNP
22 |
23 |
24 | def test_write_and_read_npz_shard(tmpdir: str | Path) -> None:
25 |     shard_file = Path(tmpdir) / "test_shard.npz"
26 | examples_per_shard = 4
27 | attr_1_shape = (128,)
28 | attr_2_shape = (8,)
29 | attr_1_values = np.random.uniform(size=(examples_per_shard, *attr_1_shape))
30 | attr_2_values = np.random.uniform(size=(examples_per_shard, *attr_2_shape))
31 |
32 | dataset_structure = DatasetStructure(
33 | saved_data_description=[
34 | Attribute(
35 | name="attr_1",
36 | dtype="float32",
37 | shape=attr_1_shape,
38 | ),
39 | Attribute(
40 | name="attr_2",
41 | dtype="float32",
42 | shape=attr_2_shape,
43 | ),
44 | ],
45 | compression="ZIP",
46 | examples_per_shard=examples_per_shard,
47 | shard_file_type="npz",
48 | )
49 |
50 | shard_writer_np = ShardWriterNP(
51 | dataset_structure=dataset_structure,
52 | shard_file=shard_file,
53 | )
54 |
55 | for i in range(examples_per_shard):
56 | shard_writer_np.write(values={
57 | "attr_1": attr_1_values[i],
58 | "attr_2": attr_2_values[i],
59 | })
60 | shard_writer_np.close()
61 |
62 | for i, e in enumerate(
63 | IterateShardNP(dataset_structure=dataset_structure,
64 | process_record=None).iterate_shard(shard_file)):
65 | assert np.allclose(e["attr_1"], attr_1_values[i])
66 | assert np.allclose(e["attr_2"], attr_2_values[i])
67 |
--------------------------------------------------------------------------------
/tests/io/test_hash_checksums.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 |
17 | import numpy as np
18 | import pytest
19 |
20 | from sedpack.io.utils import hash_checksums, hash_checksums_from_bytes
21 |
22 |
23 | def test_hash_checksums_known_values(tmpdir: str | Path) -> None:
24 | file_name = tmpdir / "file.txt"
25 | hashes = ("md5", "sha256", "sha512", "sha384")
26 | with open(file_name, "w", encoding="ascii") as f:
27 | f.write("Hello world")
28 |
29 | checksums = hash_checksums(file_name, hashes)
30 |
31 | assert checksums == (
32 | "3e25960a79dbc69b674cd4ec67a72c62",
33 | "64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c",
34 | "b7f783baed8297f0db917462184ff4f08e69c2d5e5f79a942600f9725f58ce1f29c18139bf80b06c0fff2bdd34738452ecf40c488c22a7e3d80cdf6f9c1c0d47",
35 | "9203b0c4439fd1e6ae5878866337b7c532acd6d9260150c80318e8ab8c27ce330189f8df94fb890df1d298ff360627e1",
36 | )
37 |
38 |
39 | @pytest.mark.parametrize("size", [1, 3, 7, 8, 13, 17, 33, 67])
40 | def test_equivalent(size: int, tmp_path: Path) -> None:
41 | file_content = np.random.randint(
42 | 0,
43 | 256,
44 | size=size,
45 | dtype=np.uint8,
46 | ).tobytes()
47 | hashes = (
48 | "md5",
49 | "sha256",
50 | "sha512",
51 | "sha384",
52 | "xxh32",
53 | "xxh64",
54 | "xxh128",
55 | )
56 |
57 | tmp_path = tmp_path / "shard_file.extension"
58 |
59 | with open(tmp_path, "wb") as f:
60 | f.write(file_content)
61 |
62 | # Redundant, but read again.
63 | with open(tmp_path, "rb") as f:
64 | assert f.read() == file_content
65 |
66 | assert hash_checksums_from_bytes(
67 | file_content=file_content,
68 | hashes=hashes,
69 | ) == hash_checksums(
70 | file_path=tmp_path,
71 | hashes=hashes,
72 | )
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # package specific
2 | *.zip
3 | .vscode/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | *.prof
9 | # C extensions
10 | *.so
11 | tmp/
12 | algo_arch_implem_v1_train/
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/src/sedpack/io/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Build and load tensorFlow dataset Record wrapper"""
15 |
16 | # pylint: disable=invalid-name
17 |
18 | from typing import Any, Literal, Union
19 |
20 | import numpy as np
21 | import numpy.typing as npt
22 | from typing_extensions import TypeAlias # from typing since 3.10
23 |
24 | # Valid split values (used also as directory names).
25 | SplitT: TypeAlias = Literal["train", "test", "holdout"]
26 | TRAIN_SPLIT: SplitT = "train"
27 | TEST_SPLIT: SplitT = "test"
28 | HOLDOUT_SPLIT: SplitT = "holdout"
29 |
30 | # Keras model.
31 | TFModelT: TypeAlias = Any
32 |
33 | # tf.data.Dataset and similar.
34 | TFDatasetT: TypeAlias = Any
35 |
36 | # Type of an attribute value.
37 | AttributeValueT: TypeAlias = Union[
38 | str, # UTF-8 string
39 | int,
40 | bytes,
41 | npt.NDArray[np.generic],
42 | ]
43 |
44 | # Type of a batch of attribute values.
45 | BatchedAttributeValueT: TypeAlias = Union[
46 | list[str], # UTF-8 string
47 | list[int],
48 | list[bytes],
49 | # NP has the first dimension as the batch dimension.
50 | npt.NDArray[np.generic],
51 | ]
52 |
53 | # Compression choices.
54 | CompressionT: TypeAlias = Literal[
55 | "",
56 | "BZ2",
57 | "GZIP",
58 | "LZMA",
59 | "LZ4",
60 | "ZIP",
61 | "ZLIB",
62 | "ZSTD",
63 | ]
64 |
65 | # Hash checksum types (algorithms supported by hashlib or xxhash).
66 | HashChecksumT: TypeAlias = Literal[
67 | "md5",
68 | "sha1",
69 | "sha224",
70 | "sha256",
71 | "sha384",
72 | "sha512",
73 | "sha3_224",
74 | "sha3_256",
75 | "sha3_384",
76 | "sha3_512",
77 | "xxh32",
78 | "xxh64",
79 | "xxh128",
80 | ]
81 |
82 | # Shard file-type choices. Also serves as the file-extension of shard files.
83 | ShardFileTypeT: TypeAlias = Literal[
84 | "fb",
85 | "npz",
86 | "tfrec",
87 | ]
88 |
89 | # Type alias for an example; this is what gets iterated or saved.
90 | ExampleT: TypeAlias = dict[str, AttributeValueT]
91 |
92 | BatchT: TypeAlias = dict[str, BatchedAttributeValueT]
93 |
--------------------------------------------------------------------------------
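The aliases in `types.py` are plain `typing` constructs, so they compose with standard tooling. A minimal sketch (not part of the library) of what an `ExampleT` looks like and how the `Literal` choices can be enumerated, mirroring how `tests/io/test_rust_iter.py` uses `get_args`:

```python
# Minimal sketch: ExampleT is a dict from attribute names to values, and the
# Literal aliases (CompressionT, SplitT, ...) can be enumerated via get_args.
from typing import get_args

import numpy as np

from sedpack.io.types import TRAIN_SPLIT, CompressionT, ExampleT

example: ExampleT = {
    "trace": np.zeros(138, dtype=np.float32),  # an NDArray attribute
    "label": 7,  # an int attribute
}

print(get_args(CompressionT))  # ('', 'BZ2', 'GZIP', 'LZMA', 'LZ4', 'ZIP', 'ZLIB', 'ZSTD')
print(TRAIN_SPLIT)  # 'train'
```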
/website/src/content/docs/start_here/install.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: Sedpack Installation
3 | description: Sedpack Installation
4 | ---
5 |
6 | import { Tabs, TabItem } from '@astrojs/starlight/components';
7 |
8 | Installation of the Sedpack Python package.
9 |
10 | ## Python Package Index
11 |
12 | All one should need to install the package from
13 | [PyPI](https://pypi.org/project/sedpack/) is:
14 |
15 | ```bash
16 | pip install sedpack
17 | ```
18 |
19 | Note that this is the latest stable version of the package.
20 | If you want the bleeding edge features install from source instead.
21 |
22 | ## Installing from Source
23 |
24 | One can always opt for installation from the source.
25 |
26 | Development dependencies:
27 |
28 | - python-dev and gcc for [xxhash](https://pypi.org/project/xxhash/)
29 | - [Install Rust](https://www.rust-lang.org/tools/install)
30 |
31 | Clone and create a Python virtual environment:
32 |
33 | ```bash
34 | git clone https://github.com/google/sedpack.git # Clone the repository
35 | python3 -m venv my_env # Create Python virtual environment
36 | source my_env/bin/activate # Activate your virtual environment
37 | cd sedpack/ # Change directory to the cloned git repository
38 | ```
39 |
40 | [optional] Install pinned Python dependencies:
41 |
42 | <Tabs>
43 | <TabItem label="Linux">
44 |
45 | ```bash
46 | python3 -m pip install --require-hashes -r requirements.txt
47 | ```
48 |
49 | </TabItem>
50 | <TabItem label="Windows">
51 |
52 | Currently the file `requirements.txt` is compiled for Linux; some pinned
53 | PyPI package versions may differ or may not exist on other platforms.
54 | One can always recreate the pinned requirements:
55 |
56 | ```bash
57 | python3 -m pip install pip-tools
58 | python3 -m piptools compile --generate-hashes pyproject.toml --output-file requirements_win.txt
59 | python3 -m pip install --require-hashes -r requirements_win.txt
60 | ```
61 |
62 | </TabItem>
63 | <TabItem label="macOS">
64 |
65 | Currently the file `requirements.txt` is compiled for Linux; some pinned
66 | PyPI package versions may differ or may not exist on other platforms.
67 | One can always recreate the pinned requirements:
68 |
69 | ```bash
70 | python3 -m pip install pip-tools
71 | python3 -m piptools compile --generate-hashes pyproject.toml --output-file requirements_mac.txt
72 | python3 -m pip install --require-hashes -r requirements_mac.txt
73 | ```
74 |
75 | </TabItem>
76 | </Tabs>
77 |
78 | Install sedpack with development dependencies:
79 |
80 | ```bash
81 | python3 -m pip install --editable '.[dev]'
82 | ```
83 |
84 | [optional] Run unit-tests, none should be skipped:
85 |
86 | ```bash
87 | python -m pytest
88 | ```
89 |
90 | [optional] Rebuild the Rust library after a change:
91 |
92 | ```bash
93 | maturin develop --release
94 | ```
95 |
--------------------------------------------------------------------------------
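A quick post-install smoke test (a sketch, not part of the docs): both imports below are what the unit tests use, so they should succeed in a correctly built environment, confirming that the Python package and the maturin-built Rust extension are importable.

```python
# Post-install smoke test (sketch): verify the package and Rust extension.
import sedpack
from sedpack import _sedpack_rs  # the Rust extension built by maturin

# The Rust iterator advertises which shard compressions it can decode.
print(_sedpack_rs.RustIter.supported_compressions())
```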
/src/sedpack/io/shard/iterate_shard_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for shard iteration."""
15 |
16 | from abc import ABC, abstractmethod
17 | from pathlib import Path
18 | from typing import AsyncIterator, Callable, Generic, Iterable, TypeVar
19 |
20 | from sedpack.io.metadata import DatasetStructure
21 | from sedpack.io.types import ExampleT
22 |
23 | T = TypeVar("T")
24 |
25 |
26 | class IterateShardBase(ABC, Generic[T]):
27 | """Remember everything to be able to iterate shards. This can be pickled
28 | and passed as a callable object into another process.
29 | """
30 |
31 | def __init__(self, dataset_structure: DatasetStructure,
32 | process_record: Callable[[ExampleT], T] | None) -> None:
33 | """Initialize the shard iterator.
34 |
35 | Args:
36 |
37 | dataset_structure (DatasetStructure): The structure of the dataset
38 | allowing shard parsing.
39 |
40 | process_record (Callable[[ExampleT], T] | None): How to process each
41 | record. Needs to be pickleable (for multiprocessing).
42 | """
43 | self.dataset_structure: DatasetStructure = dataset_structure
44 | self.process_record: Callable[[ExampleT], T] | None = process_record
45 |
46 | @abstractmethod
47 | def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
48 | """Iterate a shard.
49 | """
50 |
51 | @abstractmethod
52 | def iterate_shard_async(self, file_path: Path) -> AsyncIterator[ExampleT]:
53 | """Asynchronously iterate a shard.
54 | """
55 | # TODO(issue #85) fix and test async iterator typing
56 |
57 | @abstractmethod
58 | def process_and_list(self, shard_file: Path) -> list[T]:
59 | """Return a list of processed examples. Used as a function call in a
60 | different process. Returning a list as opposed to an iterator allows to
61 | do all work in another process and all that needs to be done is a
62 | memory copy between processes.
63 |
64 | Args:
65 |
66 | shard_file (Path): Path to the shard file.
67 |
68 | Returns: A list of examples present in the shard. If `process_record`
69 | is defined it is applied to all those examples.
70 | """
71 |
--------------------------------------------------------------------------------
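The `process_and_list` docstring above motivates returning a fully materialized list: the iterator object can be pickled into a worker process, which then ships back one list per shard, so the only inter-process transfer is a single memory copy. A hedged sketch of that pattern, assuming a concrete subclass such as `IterateShardNP` from `sedpack.io.npz` and a standard `multiprocessing.Pool`:

```python
# Sketch: the multiprocessing pattern process_and_list is designed for.
# Each worker decodes one whole shard and returns the examples as a list.
from multiprocessing import Pool
from pathlib import Path

from sedpack.io.metadata import DatasetStructure
from sedpack.io.npz import IterateShardNP  # a concrete IterateShardBase


def iterate_shards_in_parallel(dataset_structure: DatasetStructure,
                               shard_files: list[Path]) -> None:
    iterator = IterateShardNP(
        dataset_structure=dataset_structure,
        process_record=None,  # must be pickleable when not None
    )
    with Pool(processes=4) as pool:
        # The bound method is pickled into each worker together with the
        # iterator state; only the resulting lists cross process boundaries.
        for examples in pool.imap_unordered(iterator.process_and_list,
                                            shard_files):
            for example in examples:
                ...  # consume the decoded examples here
```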
/tests/io/test_bytes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 | from typing import Union
18 |
19 | import numpy as np
20 |
21 | import sedpack
22 | from sedpack.io import Dataset, Metadata
23 | from sedpack.io.types import TRAIN_SPLIT
24 | from sedpack.io.utils import is_module_present
25 |
26 |
27 | @pytest.mark.skipif(
28 | not is_module_present("tensorflow"),
29 | reason="TensorFlow is optional, skip test if not present.",
30 | )
31 | def test_attribute_bytes(tmpdir: Union[str, Path]) -> None:
32 | array_of_values = [
33 | bytes(
34 | np.random.randint(
35 | 0, # low
36 | 256, # high
37 | np.random.randint(10, 1_000, (), np.int32), # size
38 | np.uint8,
39 | )) for _ in range(138)
40 | ]
41 |
42 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
43 |
44 | # Create a dataset
45 |
46 | dataset_metadata = Metadata(description="Test of the lib")
47 |
48 | example_attributes = [
49 | sedpack.io.metadata.Attribute(
50 | name="attribute_name",
51 | dtype="bytes",
52 | shape=(), # ignored
53 | ),
54 | ]
55 |
56 | dataset_structure = sedpack.io.metadata.DatasetStructure(
57 | saved_data_description=example_attributes,
58 | compression="GZIP",
59 | examples_per_shard=256,
60 | shard_file_type="tfrec",
61 | )
62 |
63 | dataset = Dataset.create(
64 | path=tiny_experiment_path,
65 | metadata=dataset_metadata,
66 | dataset_structure=dataset_structure,
67 | )
68 |
69 | # Fill data in the dataset
70 |
71 | with dataset.filler(concurrency=np.random.randint(0, 4),) as filler:
72 | for attribute_value in array_of_values:
73 | filler.write_example(
74 | values={"attribute_name": attribute_value},
75 | split=TRAIN_SPLIT,
76 | )
77 |
78 | # Check the data is correct
79 |
80 | for i, example in enumerate(
81 | dataset.as_numpy_iterator(
82 | split=TRAIN_SPLIT,
83 | shuffle=0,
84 | repeat=False,
85 | )):
86 | if example["attribute_name"] != array_of_values[i]:
87 | print(example["attribute_name"])
88 | print(array_of_values[i])
89 | raise ValueError
90 |
91 | # We tested everything
92 | assert i + 1 == len(array_of_values)
93 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "npm"
9 | directory: "/website"
10 | labels:
11 | - "dependencies"
12 | - "javascript"
13 |     # Run once a month
14 | schedule:
15 | interval: "monthly"
16 | timezone: "Europe/Zurich"
17 | # Group PRs to avoid having to rebase/merge too many
18 | groups:
19 | dependabot:
20 | patterns:
21 | - "*"
22 | # Only care about our direct dependencies
23 | allow:
24 | - dependency-type: "direct"
25 | ignore:
26 | # Filter out semver patches updates to reduce the frequency of updates
27 | - dependency-name: "*"
28 | update-types: ["version-update:semver-patch"]
29 |
30 | - package-ecosystem: "github-actions"
31 | directory: "/"
32 | labels:
33 | - "dependencies"
34 | - "CI"
35 | # Run every Monday
36 | schedule:
37 | interval: "weekly"
38 | timezone: "Europe/Zurich"
39 | groups:
40 | dependabot:
41 | patterns:
42 | - "*"
43 | ignore:
44 | - dependency-name: "*"
45 | # For github-actions, we only care about major version update
46 | update-types:
47 | - "version-update:semver-patch"
48 | - "version-update:semver-minor"
49 |
50 | - package-ecosystem: "pip"
51 | directory: "/"
52 | labels:
53 | - "dependencies"
54 | - "python"
55 | # Run every Monday
56 | schedule:
57 | interval: "weekly"
58 | timezone: "Europe/Zurich"
59 | # Group PRs to avoid having to rebase/merge too many
60 | groups:
61 | dependabot:
62 | patterns:
63 | - "*"
64 | # Only care about our direct dependencies
65 | allow:
66 | - dependency-type: "direct"
67 | ignore:
68 | # Filter out semver patches updates to reduce the frequency of updates
69 | - dependency-name: "*"
70 | update-types: ["version-update:semver-patch"]
71 |
72 | - package-ecosystem: "cargo"
73 | directory: "/rust/"
74 | labels:
75 | - "dependencies"
76 | - "rust"
77 | # Run every Monday
78 | schedule:
79 | interval: "weekly"
80 | timezone: "Europe/Zurich"
81 | # Group PRs to avoid having to rebase/merge too many
82 | groups:
83 | dependabot:
84 | patterns:
85 | - "*"
86 | # Only care about our direct dependencies
87 | allow:
88 | - dependency-type: "direct"
89 | ignore:
90 | # Filter out semver patches updates to reduce the frequency of updates
91 | - dependency-name: "*"
92 | update-types: ["version-update:semver-patch"]
93 | # Peer dependencies of Cargo, not to bump independently:
94 | - dependency-name: "semver"
95 | - dependency-name: "crates-io"
96 |
--------------------------------------------------------------------------------
/tests/io/itertools/test_lazy_pool.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytest
16 |
17 | from sedpack.io.itertools import LazyPool
18 |
19 |
20 | def test_doc() -> None:
21 |
22 | def f(x: int) -> int:
23 | return 2 * x
24 |
25 | with LazyPool(5) as pool:
26 | result = set(pool.imap_unordered(f, range(10)))
27 |
28 | assert result == set(f(i) for i in range(10))
29 |
30 |
31 | def test_break_from_loop() -> None:
32 | f = lambda x: list(range(x))
33 |
34 | with LazyPool(5) as pool:
35 | for l in pool.imap_unordered(f, range(1_000**5)):
36 | if len(l) > 50:
37 | break
38 | # Not hanging.
39 |
40 |
41 | def test_two_functions() -> None:
42 | with LazyPool(3) as pool:
43 |
44 | def f(x: int) -> tuple[int, int]:
45 | return (x, x)
46 |
47 | result = set(pool.imap_unordered(f, range(10)))
48 | print(result)
49 | assert result == set((i, i) for i in range(10))
50 |
51 | def g(x: int) -> str:
52 | return f"{x} is a number"
53 |
54 | result: set[str] = set( # type: ignore[no-redef]
55 | pool.imap_unordered(g, range(10)))
56 | assert result == set(
57 | f"{i} is a number"
58 | for i in range(10)) # type: ignore[comparison-overlap]
59 |
60 |
61 | def test_break_from_loop_break_and_iterate_again() -> None:
62 | f = lambda x: list(range(x))
63 |
64 | pool = LazyPool(5)
65 |
66 | with pool as pool:
67 | seen: int = 0
68 | for l in pool.imap_unordered(f, range(1_000**5)):
69 | seen += 1
70 | if len(l) >= 50:
71 | break
72 | assert seen == 51
73 |
74 | with pool as pool:
75 | seen = 0
76 | for l in pool.imap_unordered(f, range(1_000**5)):
77 | seen += 1
78 | if len(l) >= 50:
79 | break
80 | assert seen == 51
81 |
82 |
83 | def test_no_restart_while_imap() -> None:
84 | f = lambda x: list(range(x))
85 |
86 | pool = LazyPool(5)
87 |
88 | with pool as pool:
89 | seen: int = 0
90 | for l in pool.imap_unordered(f, range(1_000**5)):
91 | seen += 1
92 | if len(l) >= 50:
93 | break
94 | assert seen == 51
95 |
96 |     # Neither the previous iteration nor the context manager has finished.
97 | with pytest.raises(AssertionError):
98 | for l in pool.imap_unordered(f, range(10)):
99 | pass
100 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard/shard.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset shard manipulation.
15 | """
16 |
17 | from pathlib import Path
18 |
19 | from sedpack.io.metadata import DatasetStructure
20 | from sedpack.io.shard_file_metadata import ShardInfo
21 | from sedpack.io.types import ExampleT
22 | from sedpack.io.shard.shard_writer_base import ShardWriterBase
23 | from sedpack.io.shard.get_shard_writer import get_shard_writer
24 |
25 |
26 | class Shard():
27 | """A shard contains N measurement pertaining to the same key"""
28 |
29 | def __init__(self, shard_info: ShardInfo,
30 | dataset_structure: DatasetStructure,
31 | dataset_root_path: Path) -> None:
32 | """Collect information about a new shard.
33 |
34 | Args:
35 |
36 | shard_info (ShardInfo): Information about this shard.
37 |
38 | dataset_structure (DatasetStructure): The structure of data being
39 | saved.
40 |
41 | dataset_root_path (Path): Path to the dataset.
42 | """
43 | # Information needed to save the shard.
44 | self.shard_info: ShardInfo = shard_info
45 | self.dataset_structure: DatasetStructure = dataset_structure
46 | self._dataset_path: Path = dataset_root_path
47 |
48 | self._shard_writer: ShardWriterBase | None = get_shard_writer(
49 | dataset_structure=dataset_structure,
50 | shard_file=self._get_full_path(),
51 | )
52 |
53 | def write(self, values: ExampleT) -> None:
54 | """Write an example on disk as TFRecord.
55 |
56 | Args:
57 |
58 | values (ExampleT): Attribute values.
59 | """
60 | if not self._shard_writer:
61 | raise ValueError("Attempting to write to a closed shard.")
62 |
63 | self._shard_writer.write(values)
64 | self.shard_info.number_of_examples += 1
65 |
66 | def close(self) -> ShardInfo:
67 | """Close shard and return statistics.
68 | """
69 | if self._shard_writer is None:
70 | raise ValueError("Closing a shard which has not been open.")
71 |
72 | hash_checksums: tuple[str, ...] = self._shard_writer.close()
73 | self._shard_writer = None
74 |
75 |         # Store the hash checksums computed by the shard writer.
76 | self.shard_info.file_infos[0].hash_checksums = hash_checksums
77 |
78 | # Return shard info.
79 | return self.shard_info
80 |
81 | def _get_full_path(self) -> Path:
82 | """Return full path to the shard file.
83 | """
84 | return self._dataset_path / self.shard_info.file_infos[0].file_path
85 |
--------------------------------------------------------------------------------
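A short sketch of the write/close lifecycle enforced above. The `shard_info`, `dataset_structure`, and `dataset_root_path` values are assumed to be built elsewhere (e.g. by the dataset filler); writing after `close()` or closing twice raises `ValueError`.

```python
# Sketch of the Shard lifecycle: write examples, then close exactly once.
from pathlib import Path
from typing import Iterable

from sedpack.io.metadata import DatasetStructure
from sedpack.io.shard.shard import Shard
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.types import ExampleT


def fill_one_shard(shard_info: ShardInfo,
                   dataset_structure: DatasetStructure,
                   dataset_root_path: Path,
                   examples: Iterable[ExampleT]) -> ShardInfo:
    shard = Shard(shard_info=shard_info,
                  dataset_structure=dataset_structure,
                  dataset_root_path=dataset_root_path)
    for values in examples:
        shard.write(values)  # also increments number_of_examples
    return shard.close()  # fills in hash checksums and detaches the writer
```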
/rust/benches/my_benchmark.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | use criterion::{Criterion, criterion_group, criterion_main};
16 | use glob::glob;
17 | use sedpack_rs::batch_iteration::BatchIterator;
18 | use sedpack_rs::example_iteration::{
19 | CompressionType, ExampleIterator, ShardInfo, get_shard_progress,
20 | };
21 | pub use sedpack_rs::parallel_map::parallel_map;
22 |
23 | pub fn get_shard_files() -> Vec<ShardInfo> {
24 | let shard_infos: Vec<_> = glob("mnist_fb_gzip/**/*.fb")
25 | .expect("Failed to load dataset")
26 | .filter_map(|p| p.ok())
27 | .map(|p| p.display().to_string())
28 | .map(|file_path| ShardInfo { file_path, compression_type: CompressionType::Gzip })
29 | .collect();
30 | println!(">> Decoding {} shards", shard_infos.len());
31 | assert_eq!(shard_infos.len(), 275);
32 | shard_infos
33 | }
34 |
35 | pub fn batch_iterator_benchmark_deterministic(c: &mut Criterion) {
36 | let shard_infos = get_shard_files();
37 | c.bench_function("BatchIterator", |b| {
38 | b.iter(|| {
39 | for batch in BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true], 0) {
40 | let _ = std::hint::black_box(batch);
41 | }
42 | })
43 | });
44 | }
45 |
46 | pub fn batch_iterator_benchmark_shuffled(c: &mut Criterion) {
47 | let shard_infos = get_shard_files();
48 | c.bench_function("BatchIterator", |b| {
49 | b.iter(|| {
50 | for batch in BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true], 256) {
51 | let _ = std::hint::black_box(batch);
52 | }
53 | })
54 | });
55 | }
56 |
57 | pub fn example_iterator_benchmark(c: &mut Criterion) {
58 | let shard_infos = get_shard_files();
59 | c.bench_function("ExampleIterator", |b| {
60 | b.iter(|| {
61 | for example in ExampleIterator::new(shard_infos.clone(), 12) {
62 | let _ = std::hint::black_box(example);
63 | }
64 | })
65 | });
66 | }
67 |
68 | pub fn parallel_map_benchmark(c: &mut Criterion) {
69 | let shard_infos = get_shard_files();
70 | c.bench_function("parallel_map", |b| {
71 | b.iter(|| {
72 | for shard in
73 | parallel_map(|x| get_shard_progress(&x), shard_infos.clone().into_iter(), 32)
74 | {
75 | let _ = std::hint::black_box(shard);
76 | }
77 | })
78 | });
79 | }
80 |
81 | criterion_group!(
82 | benches,
83 | batch_iterator_benchmark_deterministic,
84 | batch_iterator_benchmark_shuffled,
85 | example_iterator_benchmark,
86 | parallel_map_benchmark,
87 | );
88 | criterion_main!(benches);
89 |
--------------------------------------------------------------------------------
/.github/workflows/fork_pr_benchmarks_run.yml:
--------------------------------------------------------------------------------
1 | # Based on https://bencher.dev/docs/how-to/github-actions/
2 | name: Run Benchmarks
3 | permissions:
4 | contents: read
5 | pull-requests: write
6 |
7 | on:
8 | pull_request:
9 |     types: [opened, reopened, edited, synchronize]
10 | paths:
11 | - '**/*.rs'
12 | - 'rust/Cargo.toml'
13 | - 'rust/Cargo.lock'
14 |
15 | jobs:
16 | benchmark_fork_pr_branch:
17 | # Run only on the upstream repository.
18 | if: github.repository == 'google/sedpack'
19 | name: Run Fork PR Benchmarks
20 | runs-on: ubuntu-latest
21 | steps:
22 | - uses: actions/checkout@v6
23 |
24 | # begin: Use source sedpack package
25 | - uses: actions/checkout@v6
26 | - name: Set up Python 3.10
27 | uses: actions/setup-python@v6
28 | with:
29 | python-version: '3.10'
30 | cache: 'pip'
31 | - name: Get pip cache directory
32 | id: pip-cache
33 | shell: bash
34 | run: |
35 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
36 | - name: Use cached venv or create it
37 | uses: actions/cache/restore@v5
38 | id: cache
39 | with:
40 | path: ${{ steps.pip-cache.outputs.dir }}
41 |           # The cache key depends on pyproject.toml and requirements.txt
42 |           key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
43 | # Build a virtualenv, but only if it doesn't already exist
44 | - name: Populate pip cache
45 |         # requirements.txt is not reliable across platforms: pip package
46 |         # versions may vary between platforms and their versions. We regenerate
47 |         # it from pyproject.toml whenever pyproject.toml or requirements.txt
48 |         # changes. The pinned versions in requirements.txt are exercised by the
49 |         # coverage workflow, which runs on Ubuntu, the same platform used to
50 |         # produce the main requirements.txt file.
51 | run: |
52 | pip install pip-tools
53 | pip-compile --generate-hashes --extra dev pyproject.toml > requirements.txt
54 | pip install -r requirements.txt
55 | if: steps.cache.outputs.cache-hit != 'true'
56 | - name: Save cache
57 | id: cache-save
58 | uses: actions/cache/save@v5
59 | with:
60 | path: ${{ steps.pip-cache.outputs.dir }}
61 | key: ${{ steps.cache.outputs.cache-primary-key }}
62 | if: steps.cache.outputs.cache-hit != 'true'
63 | - name: Install sedpack locally
64 | run: pip install --editable .
65 | # end: Use source sedpack package
66 |
67 | - name: Prepare benchmarking data
68 | run: (cd rust/ ; bash benches/setup.sh )
69 |
70 | - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 # v1
71 | - uses: bencherdev/bencher@f89d454e74a32a81b2eab29fe0afdb2316617342 # v0.5
72 | - name: Benchmarking
73 | run: (cd rust/ ; cargo bench > benchmark_results.txt)
74 | - name: Upload Benchmark Results
75 | uses: actions/upload-artifact@v6
76 | with:
77 | name: benchmark_results.txt
78 | path: ./rust/benchmark_results.txt
79 | - name: Upload GitHub Pull Request Event
80 | uses: actions/upload-artifact@v6
81 | with:
82 | name: event.json
83 | path: ${{ github.event_path }}
84 |
--------------------------------------------------------------------------------
/tests/io/test_compression.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gzip
16 | from pathlib import Path
17 | from typing import Union
18 |
19 | import pytest
20 |
21 | from sedpack.io.compress import CompressedFile
22 | from sedpack.io.types import CompressionT
23 |
24 |
25 | def test_compress_gzip_write(tmpdir: str | Path) -> None:
26 | file_name = tmpdir / "compressed_file"
27 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
28 |
29 | # Can read what gzip writes
30 | with gzip.open(file_name, "wb") as f:
31 | f.write(payload)
32 |
33 | with open(file_name, "rb") as f:
34 | assert CompressedFile("GZIP").decompress(f.read()) == payload
35 |
36 |
37 | def test_compress_gzip_read(tmpdir: str | Path) -> None:
38 | file_name = tmpdir / "compressed_file"
39 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
40 |
41 | # Write is readable by gzip
42 | with open(file_name, "wb") as f:
43 | f.write(CompressedFile("GZIP").compress(payload))
44 |
45 | with gzip.open(file_name, "rb") as f:
46 | assert f.read() == payload
47 |
48 |
49 | def test_compress_decompress_file(tmpdir: str | Path) -> None:
50 | file_name = tmpdir / "compressed_file"
51 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
52 |
53 | for algorithm in CompressedFile.supported_compressions():
54 | with open(file_name, "wb") as f:
55 | f.write(CompressedFile(algorithm).compress(payload))
56 |
57 | with open(file_name, "rb") as f:
58 | assert CompressedFile(algorithm).decompress(f.read()) == payload
59 |
60 |
61 | def test_compress_decompress_in_memory() -> None:
62 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
63 |
64 | for algorithm in CompressedFile.supported_compressions():
65 | # Decompress of compress is the same.
66 | assert CompressedFile(algorithm).decompress(
67 | CompressedFile(algorithm).compress(payload)) == payload
68 |
69 |
70 | def test_compresses() -> None:
71 | # This should be compressible
72 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
73 |
74 | assert len(CompressedFile("GZIP").compress(payload)) < len(payload)
75 |
76 |
77 | @pytest.mark.parametrize("compression", CompressedFile.supported_compressions())
78 | def test_compression_works(compression: CompressionT,
79 | tmpdir: Union[str, Path]) -> None:
80 | file_name = tmpdir / "compressed_file"
81 | payload = bytes(x % 13 for x in range(10 * (1024**2))) # 10MB
82 |
83 | # Write is readable by gzip
84 | with open(file_name, "wb") as f:
85 | compressor = CompressedFile(compression)
86 | f.write(compressor.compress(payload))
87 |
88 | with open(file_name, "rb") as f:
89 | assert CompressedFile(compression).decompress(f.read()) == payload
90 |
--------------------------------------------------------------------------------
/tests/io/test_rust_iter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from typing import get_args, Union
17 |
18 | import numpy as np
19 | import numpy.typing as npt
20 | import pytest
21 |
22 | import sedpack
23 | from sedpack.io import Dataset
24 | from sedpack.io import Metadata
25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT
26 |
27 | from sedpack import _sedpack_rs # type: ignore[attr-defined]
28 |
29 |
30 | def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike,
31 | shard_file_type: ShardFileTypeT, compression: CompressionT) -> None:
32 | array_of_values = np.random.random((1024, 138))
33 | array_of_values = array_of_values.astype(dtype)
34 |
35 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
36 |
37 | # Create a dataset
38 |
39 | dataset_metadata = Metadata(description="Test of the lib")
40 |
41 | example_attributes = [
42 | sedpack.io.metadata.Attribute(
43 | name="attribute_name",
44 | dtype=str(dtype),
45 | shape=array_of_values[0].shape,
46 | ),
47 | ]
48 |
49 | dataset_structure = sedpack.io.metadata.DatasetStructure(
50 | saved_data_description=example_attributes,
51 | compression=compression,
52 | examples_per_shard=256,
53 | shard_file_type=shard_file_type,
54 | )
55 |
56 | dataset = Dataset.create(
57 | path=tiny_experiment_path,
58 | metadata=dataset_metadata,
59 | dataset_structure=dataset_structure,
60 | )
61 |
62 | # Fill data in the dataset
63 |
64 | with dataset.filler(concurrency=np.random.randint(0, 4),) as filler:
65 | for attribute_value in array_of_values:
66 | filler.write_example(
67 | values={"attribute_name": attribute_value},
68 | split=TRAIN_SPLIT,
69 | )
70 |
71 | # Check the data is correct
72 | # Reopen the dataset
73 | dataset = Dataset(tiny_experiment_path)
74 | dataset.check()
75 |
76 | for i, example in enumerate(
77 | dataset.as_numpy_iterator_rust(split=TRAIN_SPLIT,
78 | repeat=False,
79 | shuffle=0)):
80 | assert np.allclose(example["attribute_name"], array_of_values[i])
81 |
82 | # We tested everything
83 | assert i + 1 == array_of_values.shape[
84 | 0], "Not all examples have been iterated"
85 |
86 |
87 | @pytest.mark.parametrize("compression",
88 | list(_sedpack_rs.RustIter.supported_compressions()))
89 | def test_end2end_as_numpy_iterator_fb(compression: str,
90 | tmpdir: Union[str, Path]) -> None:
91 | end2end(
92 | tmpdir=tmpdir,
93 | dtype="float32",
94 | shard_file_type="fb",
95 | compression=compression,
96 | )
97 |
98 |
99 | def test_correct_compressions() -> None:
100 | for compression in _sedpack_rs.RustIter.supported_compressions():
101 | assert compression in set(get_args(CompressionT))
102 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard/shard_writer_tfrec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset shard manipulation.
15 |
16 | For information on how to read and write TFRecord files see
17 | https://www.tensorflow.org/tutorials/load_data/tfrecord
18 | """
19 |
20 | from pathlib import Path
21 | from typing import Any
22 |
23 | from sedpack.io.metadata import DatasetStructure
24 | from sedpack.io.tfrec.tfdata import to_tfrecord
25 | from sedpack.io.types import ExampleT, CompressionT
26 | from sedpack.io.shard.shard_writer_base import ShardWriterBase
27 |
28 |
29 | class ShardWriterTFRec(ShardWriterBase):
30 | """Shard writing capabilities.
31 | """
32 |
33 | def __init__(self, dataset_structure: DatasetStructure,
34 | shard_file: Path) -> None:
35 | """Collect information about a new shard.
36 |
37 | Args:
38 |
39 | dataset_structure (DatasetStructure): The structure of data being
40 | saved.
41 |
42 | shard_file (Path): Full path to the shard file.
43 | """
44 | assert dataset_structure.shard_file_type == "tfrec"
45 |
46 | super().__init__(
47 | dataset_structure=dataset_structure,
48 | shard_file=shard_file,
49 | )
50 |
51 | # Open the tf.io.TFRecordWriter only with the first `write` call. Make
52 | # it None immediately during a call to `close`.
53 | self._tf_shard_writer: Any | None = None
54 |
55 | def _write(self, values: ExampleT) -> None:
56 | """Write an example on disk. Writing may be buffered.
57 |
58 | Args:
59 |
60 | values (ExampleT): Attribute values.
61 | """
62 | # TensorFlow is an optional dependency.
63 | import tensorflow as tf # pylint: disable=import-outside-toplevel
64 |
65 | if (self.dataset_structure.compression
66 | not in ShardWriterTFRec.supported_compressions()):
67 | raise ValueError(
68 | f"Unsupported compression {self.dataset_structure.compression}"
69 | " requested for TFRecordWriter, expected "
70 | f"{ShardWriterTFRec.supported_compressions()}")
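71 |         # Lazily open the TFRecordWriter on the first write.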
71 | if not self._tf_shard_writer:
72 | self._tf_shard_writer = tf.io.TFRecordWriter(
73 | str(self._shard_file),
74 | self.dataset_structure.compression, # type: ignore[arg-type]
75 | )
76 |
77 | example = to_tfrecord(
78 | saved_data_description=self.dataset_structure.
79 | saved_data_description,
80 | values=values,
81 | )
82 | self._tf_shard_writer.write(example)
83 |
84 | def close(self) -> tuple[str, ...]:
85 | """Close the shard file(-s).
86 | """
87 | if not self._tf_shard_writer:
88 | raise ValueError("Trying to close a shard that was not open")
89 | self._tf_shard_writer.close()
90 | self._tf_shard_writer = None
91 | return self._compute_file_hash_checksums()
92 |
93 | @staticmethod
94 | def supported_compressions() -> list[CompressionT]:
95 | """Return a list of supported compression types.
96 | """
97 | return ["GZIP", "ZLIB", ""]
98 |
--------------------------------------------------------------------------------
/tests/io/test_end2end_async.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 | from typing import Union
18 |
19 | import numpy as np
20 | import numpy.typing as npt
21 |
22 | import sedpack
23 | from sedpack.io import Dataset
24 | from sedpack.io import Metadata
25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT
26 |
27 | pytest_plugins = ("pytest_asyncio",)
28 |
29 |
30 | async def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike,
31 |                   shard_file_type: ShardFileTypeT,
32 |                   compression: CompressionT) -> None:
33 | array_of_values = np.random.random((1024, 138))
34 | array_of_values = array_of_values.astype(dtype)
35 |
36 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
37 |
38 | # Create a dataset
39 |
40 | dataset_metadata = Metadata(description="Test of the lib")
41 |
42 | example_attributes = [
43 | sedpack.io.metadata.Attribute(
44 | name="attribute_name",
45 | dtype=str(dtype),
46 | shape=array_of_values[0].shape,
47 | ),
48 | ]
49 |
50 | dataset_structure = sedpack.io.metadata.DatasetStructure(
51 | saved_data_description=example_attributes,
52 | compression=compression,
53 | examples_per_shard=256,
54 | shard_file_type=shard_file_type,
55 | )
56 |
57 | dataset = Dataset.create(
58 | path=tiny_experiment_path,
59 | metadata=dataset_metadata,
60 | dataset_structure=dataset_structure,
61 | )
62 |
63 | # Fill data in the dataset
64 |
65 | with dataset.filler() as filler:
66 | for attribute_value in array_of_values:
67 | filler.write_example(
68 | values={"attribute_name": attribute_value},
69 | split=TRAIN_SPLIT,
70 | )
71 |
72 | # Check the data is correct
73 | # Reopen the dataset
74 | dataset = Dataset(tiny_experiment_path)
75 | dataset.check()
76 |
77 | i = 0
78 | async for example in dataset.as_numpy_iterator_async(split=TRAIN_SPLIT,
79 | shuffle=0,
80 | repeat=False):
81 | assert np.allclose(example["attribute_name"], array_of_values[i])
82 | i += 1
83 |
84 | # We tested everything
85 | assert i == array_of_values.shape[0], "Not all examples have been iterated"
86 |
87 |
88 | @pytest.mark.asyncio
89 | async def test_end2end_as_numpy_iterator_npz(tmpdir: Union[str, Path]) -> None:
90 |     await end2end(tmpdir=tmpdir,
91 |                   dtype="float32",
92 |                   shard_file_type="npz",
93 |                   compression="ZIP")
95 |
96 |
97 | @pytest.mark.asyncio
98 | async def test_end2end_as_numpy_iterator_fb(tmpdir: Union[str, Path]) -> None:
99 |     await end2end(tmpdir=tmpdir,
100 |                   dtype="float32",
101 |                   shard_file_type="fb",
102 |                   compression="LZ4")
104 |
--------------------------------------------------------------------------------
/tests/io/test_write_multiprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from typing import Union
17 |
18 | import numpy as np
19 | import numpy.typing as npt
20 |
21 | import sedpack
22 | from sedpack.io import Dataset, DatasetFiller
23 | from sedpack.io import Metadata
24 | from sedpack.io.types import SplitT, TRAIN_SPLIT
25 |
26 |
27 | def feed_writer(dataset_filler: DatasetFiller,
28 | array_of_values: npt.NDArray[np.generic], split: SplitT) -> int:
29 | # Fill data in the dataset
30 |
31 | with dataset_filler as filler:
32 | for attribute_value in array_of_values:
33 | filler.write_example(
34 | values={"attribute_name": attribute_value},
35 | split=split,
36 | )
37 |
38 | return len(array_of_values)
39 |
40 |
41 | def test_write_multiprocessing(tmpdir: Union[str, Path]) -> None:
42 | dtype = "float32"
43 | array_of_values = np.random.random((1024, 138))
44 | array_of_values = array_of_values.astype(dtype)
45 |
46 | tiny_experiment_path: Path = Path(tmpdir) / "write_multiprocessing"
47 |
48 | # Create a dataset
49 |
50 | dataset_metadata = Metadata(
51 | description="Test of the lib write_multiprocessing")
52 |
53 | example_attributes = [
54 | sedpack.io.metadata.Attribute(
55 | name="attribute_name",
56 | dtype=str(dtype),
57 | shape=array_of_values[0].shape,
58 | ),
59 | ]
60 |
61 | dataset_structure = sedpack.io.metadata.DatasetStructure(
62 | saved_data_description=example_attributes,
63 | compression="GZIP",
64 | examples_per_shard=256,
65 | shard_file_type="fb",
66 | )
67 |
68 | dataset = Dataset.create(
69 | path=tiny_experiment_path,
70 | metadata=dataset_metadata,
71 | dataset_structure=dataset_structure,
72 | )
73 |
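74 |     # Each tuple below holds the positional arguments for one feed_writer
75 |     # call; custom_kwarguments provides the matching keyword arguments.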
74 | custom_arguments = [
75 | (array_of_values[:100],),
76 | (array_of_values[100:900],),
77 | (array_of_values[900:],),
78 | ]
79 | custom_kwarguments = [
80 | {
81 | "split": "train"
82 | },
83 | {
84 | "split": "train"
85 | },
86 | {
87 | "split": "train"
88 | },
89 | ]
90 |
91 | results = dataset.write_multiprocessing(
92 | feed_writer=feed_writer,
93 | custom_arguments=custom_arguments,
94 | custom_kwarguments=custom_kwarguments,
95 | single_process=True)
96 |
97 | assert results == [
98 | len(part_of_array_of_values[0])
99 | for part_of_array_of_values in custom_arguments
100 | ]
101 |
102 | # Check the data is correct
103 |
104 | for i, example in enumerate(
105 | dataset.as_numpy_iterator_rust_batched(
106 | split=TRAIN_SPLIT,
107 | shuffle=0,
108 | repeat=False,
109 | batch_size=1,
110 | )):
111 | assert np.allclose(example["attribute_name"], array_of_values[i:i + 1])
112 |
113 | # We tested everything
114 | assert i + 1 == array_of_values.shape[
115 | 0], "Not all examples have been iterated"
116 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/shardfile/Example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # automatically generated by the FlatBuffers compiler, do not modify
16 |
17 | # namespace: shardfile
18 |
19 | # pylint: skip-file
20 |
21 | import flatbuffers
22 | from flatbuffers.builder import Builder
23 | from flatbuffers.compat import import_numpy
24 |
25 | from sedpack.io.flatbuffer.shardfile.Attribute import Attribute
26 |
27 | np = import_numpy()
28 |
29 |
30 | class Example(object):
31 | __slots__ = ['_tab']
32 |
33 | @classmethod
34 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Example":
35 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
36 | x = Example()
37 | x.Init(buf, n + offset)
38 | return x
39 |
40 | # Example
41 | def Init(self, buf: bytes, pos: int) -> None:
42 | self._tab = flatbuffers.table.Table(buf, pos)
43 |
44 | # Example
45 | def Attributes(self, j: int) -> Attribute | None:
46 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
47 | if o != 0:
48 | x = self._tab.Vector(o)
49 | x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
50 | x = self._tab.Indirect(x)
51 | obj = Attribute()
52 | obj.Init(self._tab.Bytes, x)
53 | return obj
54 | return None
55 |
56 | # Example
57 | def AttributesLength(self) -> int:
58 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type(
59 | self._tab.Offset(4))
60 | if o != 0:
61 | return self._tab.VectorLen(o) # type: ignore[no-any-return]
62 | return 0
63 |
64 | # Example
65 | def AttributesIsNone(self) -> bool:
66 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type(
67 | self._tab.Offset(4))
68 | return o == 0
69 |
70 |
71 | def ExampleStart(builder: Builder) -> None: # type: ignore[no-any-unimported]
72 | builder.StartObject(1)
73 |
74 |
75 | def Start(builder: Builder) -> None: # type: ignore[no-any-unimported]
76 | ExampleStart(builder)
77 |
78 |
79 | def ExampleAddAttributes( # type: ignore[no-any-unimported]
80 | builder: Builder, attributes: int) -> None:
81 | builder.PrependUOffsetTRelativeSlot(
82 | 0, flatbuffers.number_types.UOffsetTFlags.py_type(attributes), 0)
83 |
84 |
85 | def AddAttributes( # type: ignore[no-any-unimported]
86 | builder: Builder, attributes: int) -> None:
87 | ExampleAddAttributes(builder, attributes)
88 |
89 |
90 | def ExampleStartAttributesVector( # type: ignore[no-any-unimported]
91 | builder: Builder, numElems: int) -> int:
92 | return builder.StartVector(4, numElems, 4) # type: ignore[no-any-return]
93 |
94 |
95 | def StartAttributesVector( # type: ignore[no-any-unimported]
96 | builder: Builder, numElems: int) -> int:
97 | return ExampleStartAttributesVector(builder, numElems)
98 |
99 |
100 | def ExampleEnd(builder: Builder) -> int: # type: ignore[no-any-unimported]
101 | return builder.EndObject() # type: ignore[no-any-return]
102 |
103 |
104 | def End(builder: Builder) -> int: # type: ignore[no-any-unimported]
105 | return ExampleEnd(builder)
106 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard/shard_writer_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for shard writing depending on shard_file_type.
15 | """
16 |
17 | from abc import ABC, abstractmethod
18 | from pathlib import Path
19 |
20 | import numpy as np
21 |
22 | import sedpack
23 | from sedpack.io.metadata import DatasetStructure
24 | from sedpack.io.types import ExampleT, CompressionT
25 |
26 |
27 | class ShardWriterBase(ABC):
28 | """Shard writing capabilities.
29 | """
30 |
31 | def __init__(self, dataset_structure: DatasetStructure,
32 | shard_file: Path) -> None:
33 | """Collect information about a new shard.
34 |
35 | Args:
36 |
37 | dataset_structure (DatasetStructure): The structure of data being
38 | saved.
39 |
40 | shard_file (Path): Full path to the shard file.
41 | """
42 | # Information needed to save the shard.
43 | self.dataset_structure: DatasetStructure = dataset_structure
44 | self._shard_file: Path = shard_file
45 |
46 | # Make sure that the directory exists.
47 | self._shard_file.parent.mkdir(exist_ok=True, parents=True)
48 |
49 | # Make sure that the compression is supported for this shard file type.
50 | assert dataset_structure.compression in self.supported_compressions()
51 |
52 | def write(self, values: ExampleT) -> None:
53 | """Write an example on disk. Writing may be buffered.
54 |
55 | Args:
56 |
57 | values (ExampleT): Attribute values.
58 | """
59 | # Check the values are correct type and shape.
60 | for attribute in self.dataset_structure.saved_data_description:
61 |             # If the attribute dtype is "bytes" and the shape is an empty
62 |             # tuple, treat it as a variable-size attribute and skip the check.
63 | if attribute.has_variable_size():
64 | continue
65 |
66 | # Else check the shape (the value should be a NumPy array but maybe
67 | # it is an int or bytearray).
68 | current_shape = np.array(values[attribute.name]).shape
69 | if current_shape != attribute.shape:
70 | raise ValueError(f"Attribute {attribute.name} has shape "
71 | f"{current_shape} expected {attribute.shape}")
72 |
73 | self._write(values=values)
74 |
75 | @abstractmethod
76 | def _write(self, values: ExampleT) -> None:
77 | """Write an example on disk. Writing may be buffered.
78 |
79 | Args:
80 |
81 | values (ExampleT): Attribute values.
82 | """
83 |
84 | @abstractmethod
85 | def close(self) -> tuple[str, ...]:
86 | """Close the shard file(-s).
87 | """
88 |
89 | @staticmethod
90 | @abstractmethod
91 | def supported_compressions() -> list[CompressionT]:
92 | """Return a list of supported compression types.
93 | """
94 |
95 | def _compute_file_hash_checksums(self) -> tuple[str, ...]:
96 | """Compute hash checksums of the shard file(-s). """
97 | return sedpack.io.utils.hash_checksums(
98 | file_path=self._shard_file,
99 | hashes=self.dataset_structure.hash_checksum_algorithms,
100 | )
101 |
--------------------------------------------------------------------------------
/tests/io/shard_info_iterator/test_balanced_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import itertools
16 |
17 | from sedpack.io.shard_file_metadata import ShardInfo
18 | from sedpack.io.shard_info_iterator.balanced_iterator import _split_balancing
19 |
20 |
21 | def test_split_balancing_all() -> None:
22 | shard_list: list[ShardInfo] = [
23 | ShardInfo(
24 | file_infos=(),
25 | number_of_examples=1,
26 | custom_metadata={
27 | "id": i,
28 | "id_mod_3_is_zero": i % 3 == 0,
29 | },
30 | ) for i in range(100)
31 | ]
32 |
33 | balanced = _split_balancing(
34 | shard_list=shard_list,
35 | balance_by=[
36 | lambda shard_info: shard_info.custom_metadata["id_mod_3_is_zero"]
37 | ],
38 | repeat=False,
39 | shuffle=10,
40 | )
41 |
42 | assert set(
43 | shard_info.custom_metadata["id"] for shard_info in balanced) == set(
44 | shard_info.custom_metadata["id"] for shard_info in shard_list)
45 |
46 |
47 | def test_split_balancing_balances() -> None:
48 | shard_list: list[ShardInfo] = [
49 | ShardInfo(
50 | file_infos=(),
51 | number_of_examples=1,
52 | custom_metadata={
53 | "id": i,
54 | "id_mod_3_is_zero": i % 3 == 0,
55 | },
56 | ) for i in range(100)
57 | ]
58 |
59 | balanced = _split_balancing(
60 | shard_list=shard_list,
61 | balance_by=[
62 | lambda shard_info: shard_info.custom_metadata["id_mod_3_is_zero"]
63 | ],
64 | repeat=True,
65 | shuffle=10,
66 | )
67 |
68 | take_n = 1_000
69 | assert sum(shard_info.custom_metadata["id_mod_3_is_zero"] for shard_info in
70 | list(itertools.islice(balanced, take_n))) == take_n // 2
71 |
72 |
73 | def test_custom_weight() -> None:
74 | shard_list: list[ShardInfo] = [
75 | ShardInfo(
76 | file_infos=(),
77 | number_of_examples=1,
78 | custom_metadata={
79 | "id": i,
80 | "id_mod_3_is_zero": i % 3 == 0,
81 | },
82 | ) for i in range(100)
83 | ]
84 |
85 | class BalanceBy:
86 |
87 | def __call__(self, shard_info: ShardInfo) -> bool:
88 | return shard_info.custom_metadata["id_mod_3_is_zero"]
89 |
90 | def weight(self, shard_info: ShardInfo) -> float:
91 | if shard_info.custom_metadata["id_mod_3_is_zero"]:
92 |                 # Sample the zeros four times more often: for each non-zero
93 |                 # example there are four zero examples -> 80% of the
94 |                 # examples are zeros.
95 | return 4
96 | else:
97 | return 1
98 |
99 | balance_by_top = BalanceBy()
100 |
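101 |     # With weight() returning 4 for shards where the callable is True,
102 |     # roughly four out of five iterated shards are "zeros" (asserted below).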
101 | balanced = _split_balancing(
102 | shard_list=shard_list,
103 | balance_by=[
104 | balance_by_top,
105 | ],
106 | repeat=True,
107 | shuffle=10,
108 | )
109 |
110 | take_n = 1_000
111 | assert sum(shard_info.custom_metadata["id_mod_3_is_zero"] for shard_info in
112 | list(itertools.islice(balanced, take_n))) == take_n * (4 / 5)
113 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/shardfile/Shard.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # automatically generated by the FlatBuffers compiler, do not modify
16 | # Added best-effort type hints since this file is not likely to change.
17 |
18 | # namespace: shardfile
19 |
20 | # pylint: skip-file
21 |
22 | import flatbuffers
23 | from flatbuffers import Builder
24 | from flatbuffers.compat import import_numpy
25 |
26 | from sedpack.io.flatbuffer.shardfile.Example import Example
27 |
28 | np = import_numpy()
29 |
30 |
31 | class Shard(object):
32 | __slots__ = ['_tab']
33 |
34 | @classmethod
35 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Shard":
36 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
37 | x = Shard()
38 | x.Init(buf, n + offset)
39 | return x
40 |
41 | # Shard
42 | def Init(self, buf: bytes, pos: int) -> None:
43 | self._tab = flatbuffers.table.Table(buf, pos)
44 |
45 | # Shard
46 | def Examples(self, j: int) -> Example | None:
47 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
48 | if o != 0:
49 | x = self._tab.Vector(o)
50 | x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
51 | x = self._tab.Indirect(x)
52 | obj = Example()
53 | obj.Init(self._tab.Bytes, x)
54 | return obj
55 | return None
56 |
57 | # Shard
58 | def ExamplesLength(self) -> int:
59 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
60 | if o != 0:
61 | return self._tab.VectorLen(o) # type: ignore[no-any-return]
62 | return 0
63 |
64 | # Shard
65 | def ExamplesIsNone(self) -> bool:
66 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type(
67 | self._tab.Offset(4))
68 | return o == 0
69 |
70 |
71 | def ShardStart(builder: Builder) -> None: # type: ignore[no-any-unimported]
72 | builder.StartObject(1)
73 |
74 |
75 | def Start(builder: Builder) -> None: # type: ignore[no-any-unimported]
76 | ShardStart(builder)
77 |
78 |
79 | def ShardAddExamples( # type: ignore[no-any-unimported]
80 | builder: Builder,
81 | examples: int,
82 | ) -> None:
83 | builder.PrependUOffsetTRelativeSlot(
84 | 0,
85 | flatbuffers.number_types.UOffsetTFlags.py_type(examples),
86 | 0,
87 | )
88 |
89 |
90 | def AddExamples( # type: ignore[no-any-unimported]
91 | builder: Builder,
92 | examples: int,
93 | ) -> None:
94 | ShardAddExamples(builder, examples)
95 |
96 |
97 | def ShardStartExamplesVector( # type: ignore[no-any-unimported]
98 | builder: Builder,
99 | numElems: int,
100 | ) -> int:
101 | return builder.StartVector(4, numElems, 4) # type: ignore[no-any-return]
102 |
103 |
104 | def StartExamplesVector( # type: ignore[no-any-unimported]
105 | builder: Builder, numElems: int) -> int:
106 | return ShardStartExamplesVector(
107 | builder,
108 | numElems,
109 | )
110 |
111 |
112 | def ShardEnd(builder: Builder) -> int: # type: ignore[no-any-unimported]
113 | return builder.EndObject() # type: ignore[no-any-return]
114 |
115 |
116 | def End(builder: Builder) -> int: # type: ignore[no-any-unimported]
117 | return ShardEnd(builder)
118 |
--------------------------------------------------------------------------------
/src/sedpack/io/tfrec/read.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Read a dataset using pure asyncio.
15 | """
16 |
17 | import asyncio
18 | import os
19 | from pathlib import Path
20 | from typing import Any, AsyncIterator, Callable, Iterable
21 |
22 | import tensorflow as tf
23 |
24 | from sedpack.io.metadata import DatasetStructure
25 | from sedpack.io.shard import IterateShardBase
26 | from sedpack.io.shard.iterate_shard_base import T
27 | from sedpack.io.tfrec.tfdata import get_from_tfrecord
28 | from sedpack.io.types import ExampleT
29 | from sedpack.io.utils import func_or_identity
30 |
31 |
32 | class IterateShardTFRec(IterateShardBase[T]):
33 | """Iterate a TFRec shard.
34 | """
35 |
36 | def __init__(
37 | self,
38 | dataset_structure: DatasetStructure,
39 | process_record: Callable[[ExampleT], Any] | None,
40 | num_parallel_calls: int = os.cpu_count() or 4,
41 | ) -> None:
42 | super().__init__(dataset_structure=dataset_structure,
43 | process_record=process_record)
44 | # This is not pickleable, but can be created on the fly.
45 | self.from_tfrecord: Callable[[Any], Any] | None = None
46 | self.num_parallel_calls: int = num_parallel_calls
47 |
48 | def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
49 | """Iterate a shard saved in the TFRec format
50 | """
51 | if not self.from_tfrecord:
52 | self.from_tfrecord = get_from_tfrecord(
53 | self.dataset_structure.saved_data_description)
54 |
55 | # Read the shard.
56 | tf_dataset_records = tf.data.TFRecordDataset(
57 | str(file_path),
58 | compression_type=self.dataset_structure.
59 | compression, # type: ignore[arg-type]
60 | )
61 |
62 | # Decode examples.
63 | tf_dataset_examples = tf_dataset_records.map(
64 | self.from_tfrecord,
65 | num_parallel_calls=self.num_parallel_calls,
66 | )
67 |
68 | yield from tf_dataset_examples.as_numpy_iterator() # type: ignore[misc]
69 |
70 | # TODO(issue #85) fix and test async iterator typing
71 | async def iterate_shard_async( # pylint: disable=invalid-overridden-method
72 | self,
73 | file_path: Path,
74 | ) -> AsyncIterator[ExampleT]:
75 | for example in self.iterate_shard(file_path=file_path):
76 | yield example
77 |             # Yield control back to the event loop (a bit dirty).
78 | await asyncio.sleep(0)
79 |
80 | def process_and_list(self, shard_file: Path) -> list[T]:
81 | """Return a list of processed examples. Used as a function call in a
82 | different process. Returning a list as opposed to an iterator allows to
83 | do all work in another process and all that needs to be done is a
84 | memory copy between processes.
85 |
86 | TODO think of a way to avoid copying memory between processes.
87 |
88 | Args:
89 |
90 | shard_file (Path): Path to the shard file.
91 |
92 | Returns a list of examples present in the shard identified by the path
93 | where a `process_record` function has been applied (if not None).
94 | """
95 | process_record = func_or_identity(self.process_record)
96 |
97 | return [
98 | process_record(example)
99 | for example in self.iterate_shard(file_path=shard_file)
100 | ]
101 |
--------------------------------------------------------------------------------
/src/sedpack/io/compress.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Open a file with specified compression.
15 | """
16 |
17 | import bz2
18 | import gzip
19 | import lzma
20 |
21 | import lz4.frame
22 | import zstandard as zstd
23 |
24 | from sedpack.io.types import CompressionT
25 |
26 |
27 | class CompressedFile:
28 | """Provide an easy open function for dealing with compressed files.
29 | """
30 |
31 | def __init__(self, compression_type: CompressionT) -> None:
32 | """Initialize a compressed file opening.
33 |
34 |         compression_type (CompressionT): The type of compression. Note that
35 |             ZIP is not supported yet.
36 | """
37 | self.compression_type: CompressionT = compression_type
38 |
39 | if compression_type in ["ZIP"]:
40 | # Zip is a container meaning we open something.zip and inside that
41 | # we open file(-s). This requires more work on the context manager
42 | # side. Not implementing yet.
43 | raise NotImplementedError(f"Compression {compression_type} is not "
44 | f"supported yet by CompressedFile")
45 |
46 | @staticmethod
47 | def supported_compressions() -> list[CompressionT]:
48 | """Return a list of supported compression types.
49 | """
50 | return [
51 | "",
52 | "BZ2",
53 | "GZIP",
54 | "LZMA",
55 | "LZ4",
56 | "ZLIB",
57 | "ZSTD",
58 | ]
59 |
60 | def compress(self, data: bytes) -> bytes:
61 | """Compression.
62 |
63 | Args:
64 |
65 | data (bytes): Content to compress.
66 |
67 | Returns: the compressed data.
68 | """
69 | match self.compression_type:
70 | case "":
71 | return data
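72 |             # Note: both GZIP and ZLIB are stored using the gzip container
73 |             # here; decompress() mirrors this choice, so round-trips stay
74 |             # consistent within CompressedFile.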
72 | case "GZIP" | "ZLIB":
73 | return gzip.compress(data, compresslevel=9)
74 | case "BZ2":
75 | return bz2.compress(data, compresslevel=9)
76 | case "LZMA":
77 | return lzma.compress(data)
78 | case "LZ4":
79 | return lz4.frame.compress(data) # type: ignore[no-any-return]
80 | case "ZSTD":
81 | return zstd.compress(data)
82 | case _:
83 | raise NotImplementedError(f"CompressedFile does not implement "
84 | f"{self.compression_type} yet.")
85 |
86 | def decompress(self, data: bytes) -> bytes:
87 | """Decompression.
88 |
89 | Args:
90 |
91 | data (bytes): Content of the file to be decompressed.
92 |
93 | Returns: the decompressed data.
94 | """
95 | match self.compression_type:
96 | case "":
97 | return data
98 | case "GZIP" | "ZLIB":
99 | return gzip.decompress(data)
100 | case "BZ2":
101 | return bz2.decompress(data)
102 | case "LZMA":
103 | return lzma.decompress(data)
104 | case "LZ4":
105 | return lz4.frame.decompress(data) # type: ignore[no-any-return]
106 | case "ZSTD":
107 | return zstd.decompress(data)
108 | case _:
109 | raise NotImplementedError(f"CompressedFile does not implement "
110 | f"{self.compression_type} yet.")
111 |
--------------------------------------------------------------------------------
/src/sedpack/io/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Build and load tensorFlow dataset Record wrapper"""
15 |
16 | from pathlib import Path
17 | from typing import Union
18 |
19 | from sedpack.io.dataset_base import DatasetBase
20 | from sedpack.io.dataset_iteration import DatasetIteration
21 | from sedpack.io.dataset_iteration_tf import DatasetIterationTF
22 | from sedpack.io.dataset_writing import DatasetWriting
23 | from sedpack.io.errors import DatasetExistsError
24 | from sedpack.io.metadata import DatasetInfo, DatasetStructure, Metadata
25 |
26 |
27 | class Dataset(
28 | DatasetIteration,
29 | DatasetIterationTF,
30 | DatasetWriting,
31 | ):
32 | """Dataset class."""
33 |
34 | def __init__(self,
35 | path: Union[str, Path],
36 | create_dataset: bool = False) -> None:
37 | """Class for saving and loading a database.
38 |
39 | Args:
40 |
41 | path (Union[str, Path]): Path from where to load the dataset (a
42 | directory -- for instance
43 | "/home/user_name/datasets/my_awesome_dataset").
44 |
45 | create_dataset (bool): Are we creating a new dataset? Defaults to
46 | False which is used when (down-)loading a dataset.
47 |
48 | Raises:
49 |
50 |             ValueError if the dataset was created using a newer version of
51 |             sedpack than the one trying to load it. See
52 | sedpack.__version__ docstring.
53 |
54 | FileNotFoundError if `create_dataset` is False and the
55 | `dataset_info.json` file does not exist.
56 | """
57 | dataset_info: DatasetInfo
58 | if create_dataset:
59 | # Default DatasetInfo.
60 | dataset_info = DatasetInfo()
61 | else:
62 | # Load the information.
63 | dataset_info = DatasetBase._load(Path(path))
64 |
65 | super().__init__(path=path, dataset_info=dataset_info)
66 |
67 | @staticmethod
68 | def create(
69 | path: Union[str, Path],
70 | metadata: Metadata,
71 | dataset_structure: DatasetStructure,
72 | ) -> "Dataset":
73 | """Create an empty dataset to be filled using the `filler` or
74 | `write_multiprocessing` API.
75 |
76 | Args:
77 |
78 | path (Union[str, Path]): Path where the dataset gets saved (a
79 | directory -- for instance
80 | "/home/user_name/datasets/my_awesome_dataset").
81 |
82 | metadata (Metadata): Information about this dataset.
83 |
84 | dataset_structure (DatasetStructure): Structure of saved records.
85 |
86 | Raises: DatasetExistsError if creating this object would overwrite the
87 | corresponding config file.
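88 | 
89 |         Typical usage (a sketch; see the tests for complete examples):
90 | 
91 |             dataset = Dataset.create(path, metadata, dataset_structure)
92 |             with dataset.filler() as filler:
93 |                 filler.write_example(values=..., split="train")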
88 | """
89 | # Create a new object.
90 | dataset = Dataset(path=Path(path), create_dataset=True)
91 |
92 | # Do not overwrite an existing dataset.
93 | if Dataset._get_config_path(dataset.path).is_file():
94 | # Raise if the dataset already exists.
95 | raise DatasetExistsError(dataset.path)
96 |
97 | # Create a new dataset directory if needed.
98 | dataset.path.mkdir(parents=True, exist_ok=True)
99 |
100 | # Fill metadata and structure parameters.
101 | dataset.metadata = metadata
102 | dataset.dataset_structure = dataset_structure
103 |
104 | # Write empty config.
105 | dataset.write_config(updated_infos=[])
106 | return dataset
107 |
--------------------------------------------------------------------------------
/src/sedpack/io/shard_info_iterator/shard_info_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for a shard info iterator."""
15 | import itertools
16 | from pathlib import Path
17 |
18 | from typing import Iterator
19 |
20 | from sedpack.io.metadata import DatasetInfo
21 | from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo
22 | from sedpack.io.types import SplitT
23 |
24 |
25 | class ShardInfoIterator:
26 | """Iterate shards of a dataset.
27 | """
28 |
29 | def __init__(
30 | self,
31 | *,
32 | dataset_path: Path,
33 | dataset_info: DatasetInfo,
34 | split: SplitT | None,
35 | repeat: bool = False,
36 | ) -> None:
37 | """Initialize shard information iteration.
38 |
39 | Args:
40 |
41 | dataset_path (Path): The path to the dataset directory.
42 |
43 | dataset_info (DatasetInfo): The information about the iterated
44 | dataset.
45 |
46 | split (SplitT | None): Which split to iterate or all if set to None.
47 |
48 | repeat (bool): Should we cycle indefinitely?
49 | """
50 | self.dataset_path: Path = Path(dataset_path)
51 | self.dataset_info: DatasetInfo = dataset_info
52 | self.split: SplitT | None = split
53 | self.repeat: bool = repeat
54 |
55 | self._iterator: Iterator[ShardInfo] = iter([])
56 |
57 | def __len__(self) -> int:
58 | """Either return the number of ShardInfo objects iterated or raise a
59 | ValueError if infinite cycle.
60 | """
61 | if self.number_of_shards() == 0 or not self.repeat:
62 | return self.number_of_shards()
63 | raise ValueError("Infinite iteration")
64 |
65 | def number_of_shards(self) -> int:
66 | """Return the number of distinct shards that are iterated. When
67 | repeated this method still returns a finite answer.
68 | """
69 |         # No split specified: iterate all splits.
70 |         if self.split is None:
71 |             # Sum all splits.
72 | return sum(shard_list_info.number_of_shards
73 | for shard_list_info in self.dataset_info.splits.values())
74 |
75 | if self.split not in self.dataset_info.splits:
76 | return 0
77 |
78 | return self.dataset_info.splits[self.split].number_of_shards
79 |
80 | def _shard_info_iterator(
81 | self, shard_list_info: ShardListInfo) -> Iterator[ShardInfo]:
82 | """Recursively yield `ShardInfo` from the whole directory tree.
83 | """
84 | shard_list: ShardsList = ShardsList.model_validate_json(
85 | (self.dataset_path /
86 | shard_list_info.shard_list_info_file.file_path).read_text())
87 |
88 | yield from shard_list.shard_files
89 |
90 | for child in shard_list.children_shard_lists:
91 | yield from self._shard_info_iterator(child)
92 |
93 | def __iter__(self) -> Iterator[ShardInfo]:
94 | """Return the shard information iterator (reentrant).
95 | """
96 | if self.split is None:
97 | self._iterator = itertools.chain.from_iterable(
98 | self._shard_info_iterator(shard_list_info)
99 | for shard_list_info in self.dataset_info.splits.values())
100 | else:
101 | self._iterator = self._shard_info_iterator(
102 | self.dataset_info.splits[self.split])
103 |
104 | return self._iterator
105 |
106 | def __next__(self) -> ShardInfo:
107 | """Get the next item.
108 | """
109 | return next(self._iterator)
110 |
--------------------------------------------------------------------------------
/website/src/content/docs/tutorials/sca/overview.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: Side Channel Attacks Overview
3 | description: Side Channel Attacks Overview
4 | ---
5 |
6 | Side channel attacks (SCA) and side channel analysis (conveniently also SCA)
7 | study how to correlate data-dependent computation characteristics (e.g.,
8 | timing, power consumption, or electromagnetic emissions) with secret values.
9 | There is a rich body of research in this area. For some evaluations, large
10 | amounts of data are needed.
11 |
12 | At the [CHES 24 OPTIMIST workshop](https://optimist-ose.org/workshop-24) an
13 | initiative for [Open Tools, Interfaces and Metrics for Implementation Security
14 | Testing (OPTIMIST)](https://optimist-ose.org/) started. One of the outcomes
15 | was a call for a common format for trace storage. In this series of tutorials
16 | we argue that Sedpack is a viable solution for this purpose. If you have any
17 | questions, feature suggestions, or patches see the [sedpack GitHub
18 | repository](https://github.com/google/sedpack).
19 |
20 | The **sedpack** project started as a refactor and evolution of the data storage
21 | system used by **SCAAML (Side Channel Attacks Assisted with Machine
22 | Learning)**. See the [SCAAML website](https://google.github.io/scaaml/) or the
23 | [SCAAML GitHub repository](https://github.com/google/scaaml/).
24 |
25 | ## What's Next
26 |
27 | For the purposes of exposition we mainly focus on a very easy dataset of
28 | power consumption measurements of a textbook [AES
29 | (Wikipedia)](https://wikipedia.org/wiki/Advanced_Encryption_Standard)
30 | implementation. This dataset is both publicly available and easy to attack
31 | (analyze). If you know side channel attacks, feel free to skip to the next
32 | sections; otherwise you might want to read through the following blog posts:
33 |
34 | - [A Hacker Guide To Deep Learning Based Side Channel
35 | Attacks](https://elie.net/talk/a-hackerguide-to-deep-learning-based-side-channel-attacks)
36 |   a video of the DefCon 27 (2019) talk introducing SCAAML.
37 | - [Hacker's guide to deep-learning side-channel attacks: the
38 | theory](https://elie.net/blog/security/hacker-guide-to-deep-learning-side-channel-attacks-the-theory)
39 | a gentle introduction to the theory.
40 | - [Hacker's guide to deep-learning side-channel attacks: code
41 | walkthrough](https://elie.net/blog/security/hacker-guide-to-deep-learning-side-channel-attacks-code-walkthrough)
42 | the very first version of the SCAAML model. Later in this series of
43 | tutorials we use our most recent model on the same dataset.
44 |
45 | In this series of tutorials we showcase how the sedpack storage can be
46 | leveraged to help with side channel analysis. This series is not a
47 | one-size-fits-all, fully fledged side channel framework. The point is to
48 | showcase a tool designed to do one thing well (data storage) that enables
49 | others to work together using as universal an interface as possible. We will
50 | see that different forms of iteration (randomly shuffled or plain), possibly
51 | with batching (yielding several examples at once), provide a foundation for
52 | both machine learning and classical attack needs.
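53 | 
54 | As a minimal sketch (assuming a dataset already written to `/path/to/dataset`
55 | with a saved attribute named `trace`; the `process` and `process_batch`
56 | functions stand in for your own analysis), iteration can look roughly like
57 | this:
58 | 
59 | ```python
60 | from sedpack.io import Dataset
61 | 
62 | dataset = Dataset("/path/to/dataset")
63 | 
64 | # Plain iteration, one example at a time.
65 | for example in dataset.as_numpy_iterator_rust(split="train",
66 |                                               repeat=False,
67 |                                               shuffle=0):
68 |     process(example["trace"])
69 | 
70 | # Batched iteration, yielding several examples at once.
71 | for batch in dataset.as_numpy_iterator_rust_batched(split="train",
72 |                                                     repeat=False,
73 |                                                     shuffle=0,
74 |                                                     batch_size=32):
75 |     process_batch(batch["trace"])
76 | ```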
53 |
54 | ## Existing Tools
55 |
56 | There is a variety of excellent tools for side channel analysis, most of which
57 | provide not only storage capabilities but also analysis algorithms. This has
58 | the advantage of being ready to use. On the other hand, storing data in one of
59 | those $N$ tools and using algorithms from another creates an $N \times N$
60 | compatibility matrix (with the diagonal hopefully being trivial). Our goal is
61 | to make one storage tool which can then be used by other tools, thus saving a
62 | significant amount of duplicated work.
63 |
64 | ### An Incomplete List of Existing Tools
65 |
66 | We acknowledge that the following list is incomplete. In fact we hope that it
67 | stays incomplete, since we hope it will keep growing. In alphabetical
68 | order:
69 |
70 | - [LASCAR](https://github.com/Ledger-Donjon/lascar)
71 | - [Riscure](https://www.keysight.com/ch/de/products/network-test/device-vulnerability-analysis.html)
72 | - [SCARR](https://github.com/decryptofy/scarr)
73 |
74 | If we forgot your tool and you want it to be listed here, please open an issue
75 | or send us a pull request.
76 |
--------------------------------------------------------------------------------
/tests/io/test_shard_custom_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from typing import Union
17 |
18 | import numpy as np
19 |
20 | import sedpack
21 | from sedpack.io import Dataset
22 | from sedpack.io import Metadata
23 | from sedpack.io.shard_file_metadata import ShardInfo
24 | from sedpack.io.types import TRAIN_SPLIT
24 |
25 |
26 | def test_custom_shard_metadata(tmpdir: Union[str, Path]) -> None:
27 | attribute_dtype = "float32"
28 | shard_file_type = "fb"
29 |
30 | array_of_values = np.random.random((32, 138)).astype(dtype=attribute_dtype)
31 |
32 | experiment_path: Path = Path(tmpdir) / "custom_shard_metadata_experiment"
33 |
34 | # Create a dataset
35 | dataset_metadata = Metadata(description="Test of the lib")
36 |
37 | example_attributes = [
38 | sedpack.io.metadata.Attribute(
39 | name="attribute_name",
40 | dtype=attribute_dtype,
41 | shape=array_of_values[0].shape,
42 | ),
43 | ]
44 |
45 | dataset_structure = sedpack.io.metadata.DatasetStructure(
46 | saved_data_description=example_attributes,
47 | compression="GZIP",
48 | examples_per_shard=4,
49 | shard_file_type=shard_file_type,
50 | )
51 |
52 | dataset = Dataset.create(
53 | path=experiment_path,
54 | metadata=dataset_metadata,
55 | dataset_structure=dataset_structure,
56 | )
57 |
58 | # Fill data in the dataset
59 | custom_metadata_0 = {"key": {"key2": "valueA"}}
60 | custom_metadata_1 = {"key": {"key2": "valueB"}}
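61 |     # Writing an example with different custom_metadata forces a new shard;
62 |     # a shard also rolls over once examples_per_shard (4) is reached.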
61 |
62 | with dataset.filler() as filler:
63 | # No custom metadata.
64 | filler.write_example(
65 | values={"attribute_name": array_of_values[0]},
66 | split=TRAIN_SPLIT,
67 | )
68 | # Still the same shard, retroactively setting metadata here.
69 | filler.write_example(
70 | values={"attribute_name": array_of_values[1]},
71 | split=TRAIN_SPLIT,
72 | custom_metadata=custom_metadata_0,
73 | )
74 |         # Another shard has been opened.
75 | filler.write_example(
76 | values={"attribute_name": array_of_values[2]},
77 | split=TRAIN_SPLIT,
78 | custom_metadata=custom_metadata_1,
79 | )
80 | # Still the same shard.
81 | filler.write_example(
82 | values={"attribute_name": array_of_values[3]},
83 | split=TRAIN_SPLIT,
84 | custom_metadata=custom_metadata_1,
85 | )
86 | # Still the same shard.
87 | filler.write_example(
88 | values={"attribute_name": array_of_values[4]},
89 | split=TRAIN_SPLIT,
90 | custom_metadata=custom_metadata_1,
91 | )
92 | # Still the same shard.
93 | filler.write_example(
94 | values={"attribute_name": array_of_values[5]},
95 | split=TRAIN_SPLIT,
96 | custom_metadata=custom_metadata_1,
97 | )
98 | # Shard full, another opened.
99 | filler.write_example(
100 | values={"attribute_name": array_of_values[6]},
101 | split=TRAIN_SPLIT,
102 | custom_metadata=custom_metadata_1,
103 | )
104 |
105 | # The object in memory and the saved metadata are the same.
106 | assert dataset._dataset_info == Dataset(experiment_path)._dataset_info
107 |
108 | # There are three shards with the custom metadata.
109 | shards: list[ShardInfo] = list(dataset.shard_info_iterator(TRAIN_SPLIT))
110 | assert len(shards) == 3
111 | assert shards[0].custom_metadata == custom_metadata_0
112 | assert shards[1].custom_metadata == custom_metadata_1
113 | assert shards[2].custom_metadata == custom_metadata_1
114 |
--------------------------------------------------------------------------------
/src/sedpack/io/flatbuffer/shardfile/Attribute.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # automatically generated by the FlatBuffers compiler, do not modify
16 |
17 | # namespace: shardfile
18 |
19 | # pylint: skip-file
20 |
21 | import flatbuffers
22 |
23 | import numpy as np
24 | import numpy.typing as npt
25 |
26 |
27 | class Attribute(object):
28 | __slots__ = ['_tab']
29 |
30 | @classmethod
31 | def GetRootAs(cls, buf: bytes, offset: int = 0) -> "Attribute":
32 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
33 | x = Attribute()
34 | x.Init(buf, n + offset)
35 | return x
36 |
37 | # Attribute
38 | def Init(self, buf: bytes, pos: int) -> None:
39 | self._tab = flatbuffers.table.Table(buf, pos)
40 |
41 | # Attribute
42 | def AttributeBytes(self, j: int) -> bytes:
43 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
44 | if o != 0:
45 | a = self._tab.Vector(o)
46 | return self._tab.Get( # type: ignore[no-any-return]
47 | flatbuffers.number_types.Uint8Flags,
48 | a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
49 | return bytes([])
50 |
51 | # Attribute
52 | def AttributeBytesAsNumpy(self) -> npt.NDArray[np.uint8]:
53 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
54 | if o != 0:
55 | return self._tab.GetVectorAsNumpy( # type: ignore[no-any-return]
56 | flatbuffers.number_types.Uint8Flags, o)
57 | return np.array([], dtype=np.uint8)
58 |
59 | # Attribute
60 | def AttributeBytesLength(self) -> int:
61 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
62 | if o != 0:
63 | return self._tab.VectorLen(o) # type: ignore[no-any-return]
64 | return 0
65 |
66 | # Attribute
67 | def AttributeBytesIsNone(self) -> bool:
68 | o: int = flatbuffers.number_types.UOffsetTFlags.py_type(
69 | self._tab.Offset(4))
70 | return o == 0
71 |
72 |
73 | def AttributeStart( # type: ignore[no-any-unimported]
74 | builder: flatbuffers.builder.Builder) -> None:
75 | builder.StartObject(1)
76 |
77 |
78 | def Start( # type: ignore[no-any-unimported]
79 | builder: flatbuffers.builder.Builder) -> None:
80 | AttributeStart(builder)
81 |
82 |
83 | def AttributeAddAttributeBytes( # type: ignore[no-any-unimported]
84 | builder: flatbuffers.builder.Builder, attributeBytes: int) -> None:
85 | builder.PrependUOffsetTRelativeSlot(
86 | 0,
87 | flatbuffers.number_types.UOffsetTFlags.py_type(attributeBytes),
88 | 0,
89 | )
90 |
91 |
92 | def AddAttributeBytes( # type: ignore[no-any-unimported]
93 | builder: flatbuffers.builder.Builder, attributeBytes: int) -> None:
94 | AttributeAddAttributeBytes(builder, attributeBytes)
95 |
96 |
97 | def AttributeStartAttributeBytesVector( # type: ignore[no-any-unimported]
98 | builder: flatbuffers.builder.Builder, numElems: int) -> int:
99 | return builder.StartVector(1, numElems, 1) # type: ignore[no-any-return]
100 |
101 |
102 | def StartAttributeBytesVector( # type: ignore[no-any-unimported]
103 | builder: flatbuffers.builder.Builder, numElems: int) -> int:
104 | return AttributeStartAttributeBytesVector(builder, numElems)
105 |
106 |
107 | def AttributeEnd( # type: ignore[no-any-unimported]
108 | builder: flatbuffers.builder.Builder) -> int:
109 | return builder.EndObject() # type: ignore[no-any-return]
110 |
111 |
112 | def End( # type: ignore[no-any-unimported]
113 | builder: flatbuffers.builder.Builder) -> int:
114 | return AttributeEnd(builder)
115 |
--------------------------------------------------------------------------------
/base-tooling-requirements.txt:
--------------------------------------------------------------------------------
1 | build==1.2.1 --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4
2 | click==8.2.0 --hash=sha256:6b303f0b2aa85f1cb4e5303078fadcbcd4e476f114fab9b5007005711839325c --hash=sha256:f5452aeddd9988eefa20f90f05ab66f17fce1ee2a36907fd30b05bbb5953814d
3 | importlib-metadata==8.7.0 --hash=sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000 --hash=sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd
4 | packaging==25.0 --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f
5 | pip-tools==7.4.0 --hash=sha256:a92a6ddfa86ff389fe6ace381d463bc436e2c705bd71d52117c25af5ce867bb7 --hash=sha256:b67432fd0759ed834c5367f9e0ce8c95441acecfec9c8e24b41aca166757adf0
6 | pyproject-hooks==1.2.0 --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913
7 | tomli==2.2.1 --hash=sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6 --hash=sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd --hash=sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c --hash=sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b --hash=sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8 --hash=sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6 --hash=sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77 --hash=sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff --hash=sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea --hash=sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192 --hash=sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249 --hash=sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee --hash=sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4 --hash=sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98 --hash=sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8 --hash=sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4 --hash=sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281 --hash=sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744 --hash=sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69 --hash=sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13 --hash=sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140 --hash=sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e --hash=sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e --hash=sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc --hash=sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff --hash=sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec --hash=sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2 --hash=sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222 --hash=sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106 --hash=sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272 --hash=sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a --hash=sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7
8 | typing-extensions==4.13.1 --hash=sha256:4b6cf02909eb5495cfbc3f6e8fd49217e6cc7944e145cdda8caa3734777f9e69 --hash=sha256:98795af00fb9640edec5b8e31fc647597b4691f099ad75f469a2616be1a76dff
9 | wheel==0.45.0 --hash=sha256:52f0baa5e6522155090a09c6bd95718cc46956d1b51d537ea5454249edb671c7 --hash=sha256:a57353941a3183b3d5365346b567a260a0602a0f8a635926a7dede41b94c674a
10 | zipp==3.23.0 --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166
11 | pip==25.3 --hash=sha256:8d0538dbbd7babbd207f261ed969c65de439f6bc9e5dbd3b3b9a77f25d95f343 --hash=sha256:9655943313a94722b7774661c21049070f6bbb0a1516bf02f7c8d5d9201514cd
12 | setuptools==80.9.0 --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c
13 |
--------------------------------------------------------------------------------
/docs/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of
9 | experience, education, socio-economic status, nationality, personal appearance,
10 | race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or reject
41 | comments, commits, code, wiki edits, issues, and other contributions that are
42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any
43 | contributor for other behaviors that they deem inappropriate, threatening,
44 | offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when the Project
56 | Steward has a reasonable belief that an individual's behavior may have a
57 | negative impact on the project or its community.
58 |
59 | ## Conflict Resolution
60 |
61 | We do not believe that all conflict is bad; healthy debate and disagreement
62 | often yield positive results. However, it is never okay to be disrespectful or
63 | to engage in behavior that violates the project’s code of conduct.
64 |
65 | If you see someone violating the code of conduct, you are encouraged to address
66 | the behavior directly with those involved. Many issues can be resolved quickly
67 | and easily, and this gives people more control over the outcome of their
68 | dispute. If you are unable to resolve the matter for any reason, or if the
69 | behavior is threatening or harassing, report it. We are dedicated to providing
70 | an environment where participants feel welcome and safe.
71 |
72 | Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the
73 | Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to
74 | receive and address reported violations of the code of conduct. They will then
75 | work with a committee consisting of representatives from the Open Source
76 | Programs Office and the Google Open Source Strategy team. If for any reason you
77 | are uncomfortable reaching out to the Project Steward, please email
78 | [opensource@google.com](mailto:opensource@google.com).
79 |
80 | We will investigate every complaint, but you may not receive a direct response.
81 | We will use our discretion in determining when and how to follow up on reported
82 | incidents, which may range from not taking action to permanent expulsion from
83 | the project and project-sponsored spaces. We will notify the accused of the
84 | report and provide them an opportunity to discuss it before any action is taken.
85 | The identity of the reporter will be omitted from the details of the report
86 | supplied to the accused. In potentially harmful situations, such as ongoing
87 | harassment or threats to anyone's safety, we may take action without notice.
88 |
89 | ## Attribution
90 |
91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
92 | available at
93 | [https://www.contributor-covenant.org/version/1/4/code-of-conduct/](https://www.contributor-covenant.org/version/1/4/code-of-conduct/)
94 |
--------------------------------------------------------------------------------
/tests/io/shard/test_shard_write_async.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 |
18 | import numpy as np
19 |
20 | from sedpack.io.metadata import Attribute, DatasetStructure
21 | from sedpack.io.types import ShardFileTypeT
22 |
23 | from sedpack.io.flatbuffer import IterateShardFlatBuffer
24 | from sedpack.io.npz import IterateShardNP
25 | from sedpack.io.shard.get_shard_writer import get_shard_writer
26 | from sedpack.io.shard.iterate_shard_base import IterateShardBase
27 | pytest_plugins = ("pytest_asyncio",)
28 |
29 |
30 | async def shard_write_and_read(attributes: dict[str,
31 | np.ndarray], shard_file: Path,
32 | shard_file_type: ShardFileTypeT) -> None:
33 | dataset_structure = DatasetStructure(
34 | saved_data_description=[
35 | Attribute(
36 | name=name,
37 | shape=value.shape[1:],
38 | dtype=str(value.dtype),
39 | ) for name, value in attributes.items()
40 | ],
41 | shard_file_type=shard_file_type,
42 | compression="",
43 | )
44 |
45 | # Write data into the file.
46 | writer = get_shard_writer(dataset_structure=dataset_structure,
47 | shard_file=shard_file)
48 | one_value = next(iter(attributes.values())) # One of the values.
49 | for i in range(one_value.shape[0]):
50 | writer.write(values={
51 | name: value[i] for name, value in attributes.items()
52 | })
53 | writer.close()
54 |
55 | iterate_shard: IterateShardBase
56 | match shard_file_type:
57 | case "npz":
58 | iterate_shard = IterateShardNP(dataset_structure=dataset_structure,
59 | process_record=None)
60 | case "fb":
61 | iterate_shard = IterateShardFlatBuffer(
62 | dataset_structure=dataset_structure, process_record=None)
63 | case _:
64 | raise ValueError(f"Unknown {shard_file_type = }")
65 |
66 | # Read those back.
67 | seen: int = 0
68 | i: int = 0
69 | async for example in iterate_shard.iterate_shard_async(shard_file):
70 | for name, value in attributes.items():
71 | np.testing.assert_allclose(example[name], value[i])
72 | seen += 1
73 | i += 1 # manual enumerate
74 | assert seen == one_value.shape[0]
75 |
76 |
77 | @pytest.mark.asyncio
78 | async def test_async_npz_with_int(tmp_path):
79 | shard_file = tmp_path / "shard_file.npz"
80 | attributes = {
81 | "a": np.array([[13 + 512, 2, 3], [4, 5, 6]]),
82 | }
83 | await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
84 |
85 |
86 | @pytest.mark.asyncio
87 | async def test_async_npz_with_float(tmp_path):
88 | shard_file = tmp_path / "shard_file.npz"
89 | attributes = {
90 | "a": np.random.uniform(size=(10, 15)),
91 | }
92 | await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
93 |
94 |
95 | @pytest.mark.asyncio
96 | async def test_async_npz_mixed(tmp_path):
97 | shard_file = tmp_path / "shard_file.npz"
98 | attributes = {
99 | "a": np.random.uniform(size=(10, 15)),
100 | "b": np.random.uniform(size=(10, 25)),
101 | "c": np.random.randint(-5, 20, size=(10, 21)),
102 | }
103 | await shard_write_and_read(attributes, shard_file, shard_file_type="npz")
104 |
105 |
106 | @pytest.mark.asyncio
107 | async def test_async_fb_with_int(tmp_path):
108 | shard_file = tmp_path / "shard_file"
109 | attributes = {
110 | "a": np.array([[13 + 512, 2, 3], [4, 5, 6]]),
111 | }
112 | await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
113 |
114 |
115 | @pytest.mark.asyncio
116 | async def test_async_fb_with_float(tmp_path):
117 | shard_file = tmp_path / "shard_file"
118 | attributes = {
119 | "a": np.random.uniform(size=(10, 15)),
120 | }
121 | await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
122 |
123 |
124 | @pytest.mark.asyncio
125 | async def test_async_fb_mixed(tmp_path):
126 | shard_file = tmp_path / "shard_file"
127 | attributes = {
128 | "a": np.random.uniform(size=(10, 15)),
129 | "b": np.random.uniform(size=(10, 25)),
130 | "c": np.random.randint(-5, 20, size=(10, 21)),
131 | }
132 | await shard_write_and_read(attributes, shard_file, shard_file_type="fb")
133 |
--------------------------------------------------------------------------------
/website/src/content/docs/tutorials/sca/dataset.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: Converting the TinyAES Dataset into Sedpack
3 | description: Converting the TinyAES Dataset into Sedpack
4 | ---
5 |
6 | The TinyAES dataset is a collection of power measurements of a textbook [AES
7 | (Wikipedia)](https://de.wikipedia.org/wiki/Advanced_Encryption_Standard)
8 | implementation. It was introduced for the DEF CON 27 demo: [A Hacker Guide To
9 | Deep Learning Based Side Channel
10 | Attacks](https://elie.net/talk/a-hackerguide-to-deep-learning-based-side-channel-attacks).
11 | In this tutorial we show how to convert the original data into the sedpack
12 | format.
13 |
14 | ## Original Dataset
15 |
16 | The dataset was captured from STM32F4 chips using the ChipWhisperer [CW308 UFO
17 | board](https://www.newae.com/chipwhisperer). The capture was asynchronous,
18 | meaning that the oscilloscope clock signal and the target chip clock signal
19 | were not synchronized (the oscilloscope was oversampling to cope with this).
20 | The oscilloscope used was a PicoScope® 6404D.
21 |
22 | One can download the zipped dataset either by clicking the following link
23 | [datasets.zip
24 | (8.2GB)](https://storage.googleapis.com/scaaml-public/scaaml_intro/datasets.zip)
25 | or by running:
26 |
27 | ```bash
28 | wget https://storage.googleapis.com/scaaml-public/scaaml_intro/datasets.zip
29 | sha256sum datasets.zip # 4bf2c6defb79b40b30f01f488e83762396b56daad14a694f64916be2b665b2f8
30 | unzip datasets.zip
31 | ```
32 |
33 | The original files were saved in the
34 | [NumPy](https://numpy.org/doc/1.26/reference/generated/numpy.savez.html)
35 | format. Each file has a constant key and balanced plaintexts (see [SCAAML
36 | Attack Point
37 | Iterators](https://google.github.io/scaaml/guides/capture/attack_point_iterators/)
38 | for an explanation). This is completely fine except that:
39 |
40 | - There is no standard format -- anyone who wants to experiment with this
41 | dataset has to discover the data format and write custom data loaders.
42 | Not complicated, but tedious.
43 | - This approach does not scale well for huge datasets.
44 |
45 | ## Conversion to Sedpack
46 |
47 | The script
48 | [docs/tutorials/sca/tiny_aes.py](https://github.com/google/sedpack/blob/main/docs/tutorials/sca/tiny_aes.py)
49 | contains a function `convert_to_sedpack` which takes the directory with
50 | original NumPy files and converts it to a sedpack dataset.
51 |
52 | We save the following attributes:
53 |
54 | - `trace1`: float16, length 80_000
55 | - `key`: uint8, length 16
56 | - `plaintext`: uint8, length 16
57 | - `ciphertext`: uint8, length 16
58 | - `sub_bytes_out`: uint8, length 16
59 | - `sub_bytes_in`: uint8, length 16
60 |
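For reference, a minimal sketch of how such attributes are declared with the
sedpack API (illustrative only; the authoritative declarations live in
`tiny_aes.py`):

```python
from sedpack.io import Attribute, DatasetStructure

# Attribute declarations matching the list above. Compression and shard file
# type are left at their defaults here; tiny_aes.py may configure them
# differently.
dataset_structure = DatasetStructure(saved_data_description=[
    Attribute(name="trace1", shape=(80_000,), dtype="float16"),
    Attribute(name="key", shape=(16,), dtype="uint8"),
    Attribute(name="plaintext", shape=(16,), dtype="uint8"),
    Attribute(name="ciphertext", shape=(16,), dtype="uint8"),
    Attribute(name="sub_bytes_out", shape=(16,), dtype="uint8"),
    Attribute(name="sub_bytes_in", shape=(16,), dtype="uint8"),
])
```
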
61 | Most of our scenarios deal with profiled attacks, where during the profiling
62 | phase we know the secret (`key`) and during the attack phase we try to recover
63 | it. For convenience we also save the `ciphertext` (the result of encrypting
64 | the given `plaintext` with the `key`) and the first round of AES S-BOX inputs
65 | and outputs, although we could omit those. In a profiled setting (we train on
66 | one device and attack another; these phases are called the profiling phase and
67 | the attack phase) `ciphertext`, `sub_bytes_in`, and `sub_bytes_out` could be
68 | considered redundant information. In the non-profiled setting (a single device
69 | where we do not know the `key`) we would have access only to `plaintext` and
70 | `ciphertext`.
71 |
72 | Naturally, if we wanted to save more attack points, e.g., the last round of
73 | AES state or all inputs of the S-BOX, we could. Alternatively we can choose
74 | to compute those on the fly when and if we need them, as sketched below.
75 |
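For instance, the first-round S-BOX output can be recomputed from `plaintext`
and `key` with a `process_record` callback. A minimal sketch, assuming the
reader supplies the real `AES_SBOX` table (the identity placeholder below only
keeps the snippet runnable) and the dataset was converted as above:

```python
import numpy as np

from sedpack.io import Dataset

# Placeholder: substitute the real 256-entry AES S-BOX lookup table here
# (available in any AES reference implementation).
AES_SBOX = np.arange(256, dtype=np.uint8)


def add_sub_bytes_out(example):
    # The round-one SubBytes input is plaintext XOR key; its S-BOX image is
    # the output. Computed on the fly instead of being read from disk.
    example["sub_bytes_out"] = AES_SBOX[example["plaintext"] ^ example["key"]]
    return example


dataset = Dataset("tiny_aes_sedpack")
for example in dataset.as_numpy_iterator(split="train",
                                         process_record=add_sub_bytes_out):
    break  # Inspect a single example.
```
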
76 | A limitation of the TinyAES dataset is that all data was collected on a
77 | single device (both the original `train` and `test` splits). This is one of
78 | the reasons why, for profiled attacks, this dataset is suitable for
79 | demonstration purposes only. In a realistic scenario one would capture `train`
80 | and `test` on the profiling chip and `holdout` on a different physical device
81 | (of the same type). We deal with the lack of a `holdout` split by splitting
82 | the original `test` into `test` and `holdout`.
83 |
84 | ### Iteration
85 |
86 | Let us check the data we just converted.
87 |
88 | ```python
89 | from sedpack.io import Dataset
90 |
91 | dataset = Dataset("tiny_aes_sedpack")
92 | for example in dataset.as_numpy_iterator(split="train",
93 | shuffle=0,
94 | repeat=False):
95 | print(example["key"])
96 | ```
97 |
98 | We should see groups of 256 examples with exactly the same `key` value. The
99 | same holds for the `test` and `holdout` splits.
100 |
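One illustrative way to verify this grouping is to measure the run lengths of
consecutive identical keys (`dataset` as in the snippet above):

```python
import itertools

# Lengths of runs of consecutive examples sharing the same key value.
run_lengths = [
    sum(1 for _ in group) for _, group in itertools.groupby(
        dataset.as_numpy_iterator(split="train", shuffle=0, repeat=False),
        key=lambda example: example["key"].tobytes(),
    )
]
print(set(run_lengths))  # Expected: {256}
```
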
101 | Finally, we plot a single trace just to be sure everything works as expected.
102 |
103 | ```python
104 | import matplotlib.pyplot as plt
105 |
106 | # Plot a single trace.
107 | plt.plot(next(iter(dataset.as_numpy_iterator(split="train")))["trace1"])
108 | plt.savefig("tiny_aes_trace.png")
109 | ```
110 |
111 | 
112 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: pytest
2 | permissions:
3 | contents: read
4 | pull-requests: write
5 | on:
6 | pull_request:
7 | types: [opened, synchronize, reopened]
8 | paths:
9 | - '**/*.py'
10 | - '**/*.rs'
11 | - 'pytest.ini'
12 | schedule:
13 | - cron: 0 5 * * 1 # Every Monday at 5:00 UTC
14 | merge_group: # Needed for required workflows
15 | # Run after a review has been submitted (this is a required workflow which
16 | # might not be triggered when no code changes -- trigger before going to
17 | # merge queue).
18 | pull_request_review:
19 | types: [submitted]
20 |
21 | jobs:
22 | unittesting:
23 | runs-on: ${{ matrix.platform.runner }}
24 | strategy:
25 | matrix:
26 | # ubuntu-20.04-arm was not stable enough when testing
27 | platform:
28 | - runner: ubuntu-latest # x64
29 | - runner: windows-latest # x64
30 | - runner: macos-14 # arm64
31 | - runner: macos-15-intel # Intel
32 | - runner: macos-latest # arm64
33 | if: github.event_name != 'schedule'
34 | steps:
35 | - uses: actions/checkout@v6
36 | - name: Set up Python 3.10
37 | uses: actions/setup-python@v6
38 | with:
39 | python-version: '3.10'
40 | cache: 'pip'
41 | - name: Get pip cache directory
42 | id: pip-cache
43 | shell: bash
44 | run: |
45 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
46 | - name: Use cached venv or create it
47 | uses: actions/cache/restore@v5
48 | id: cache
49 | with:
50 | path: ${{ steps.pip-cache.outputs.dir }}
51 | key: ${{ matrix.platform.runner }}-pip
52 | # Build a virtualenv, but only if it doesn't already exist
53 | - name: Populate pip cache
54 | # requirements.txt is not reliable since pip package versions might vary
55 | # across different platforms and platform versions. We regenerate it from
56 | # pyproject.toml whenever pyproject.toml or requirements.txt changes. The
57 | # pinned versions in requirements.txt are tested by the coverage job since
58 | # that runs on ubuntu, which is also used to produce the main
59 | # requirements.txt file.
60 | run: |
61 | pip install pip==25.2 # TODO(remove the pinning) pip-tools issue 2252
62 | pip install pip-tools
63 | pip-compile --generate-hashes --extra dev pyproject.toml > dev_requirements.txt
64 | pip install -r dev_requirements.txt
65 | if: steps.cache.outputs.cache-hit != 'true'
66 | - name: Save cache
67 | id: cache-save
68 | uses: actions/cache/save@v5
69 | with:
70 | path: ${{ steps.pip-cache.outputs.dir }}
71 | key: ${{ steps.cache.outputs.cache-primary-key }}
72 | if: steps.cache.outputs.cache-hit != 'true'
73 | - name: Install sedpack locally
74 | run: pip install --editable ".[dev]"
75 | - name: Running unit tests
76 | run: python -m pytest
77 |
78 | coverage:
79 | runs-on: ubuntu-22.04
80 | steps:
81 | - uses: actions/checkout@v6
82 | - name: Set up Python 3.10
83 | uses: actions/setup-python@v6
84 | with:
85 | python-version: '3.10'
86 | cache: 'pip'
87 | - name: Get pip cache directory
88 | id: pip-cache
89 | shell: bash
90 | run: |
91 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
92 | - name: Use cached venv or create it
93 | uses: actions/cache/restore@v5
94 | id: cache
95 | with:
96 | path: ${{ steps.pip-cache.outputs.dir }}
97 | # The cache key depends on requirements.txt
98 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
99 | # Build a virtualenv, but only if it doesn't already exist
100 | - name: Populate pip cache
101 | run: |
102 | python -m pip install --require-hashes --no-deps -r requirements.txt
103 | if: steps.cache.outputs.cache-hit != 'true'
104 | - name: Save cache
105 | id: cache-save
106 | uses: actions/cache/save@v5
107 | with:
108 | path: ${{ steps.pip-cache.outputs.dir }}
109 | key: ${{ steps.cache.outputs.cache-primary-key }}
110 | if: steps.cache.outputs.cache-hit != 'true'
111 | - name: Installing test requirements and sedpack
112 | # Start by "installing" sedpack to be sure all dependencies are listed
113 | run: |
114 | pip install --editable ".[dev]"
115 | echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV
116 | - name: Install workflow dependencies
117 | run: pip install --upgrade pytest coverage
118 | - name: Running unit tests with coverage
119 | env:
120 | DISABLE_AUTOGRAPH: 1
121 | # TODO remove the -i (ignore errors)
122 | run: |
123 | coverage run -m pytest
124 | coverage xml -i
125 | - name: Upload results
126 | uses: coverallsapp/github-action@648a8eb78e6d50909eff900e4ec85cab4524a45b # v2
127 | with:
128 | file: coverage.xml
129 |
130 |
--------------------------------------------------------------------------------
/docs/tutorials/quick_start/mnist_save.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Download the MNIST dataset and save it in dataset-lib format. For a tutorial
15 | with explanations see: https://google.github.io/sedpack/tutorials/mnist
16 |
17 | Example use:
18 | python mnist_save.py -d "~/Datasets/mnist_dataset/"
19 | python mnist_read_keras.py -d "~/Datasets/mnist_dataset/"
20 | """
21 |
22 | import argparse
23 | import random
24 | from typing import get_args
25 |
26 | from tensorflow import keras
27 | from tqdm import tqdm
28 |
29 | from sedpack.io import Dataset, Metadata, DatasetStructure, Attribute
30 | from sedpack.io.types import SplitT
31 | from sedpack.io.types import CompressionT, ShardFileTypeT
32 |
33 |
34 | def main() -> None:
35 | """Convert the MNIST dataset into sedpack format.
36 | """
37 | parser = argparse.ArgumentParser(
38 | description="Convert the MNIST dataset into the sedpack format")
39 | parser.add_argument("--dataset_directory",
40 | "-d",
41 | help="Where to save the dataset",
42 | required=True)
43 | parser.add_argument("--compression",
44 | "-c",
45 | help="Which compression algorithm to use",
46 | default="GZIP",
47 | choices=get_args(CompressionT))
48 | parser.add_argument("--shard_file_type",
49 | "-f",
50 | help="Which shard file type to use",
51 | default="tfrec",
52 | choices=get_args(ShardFileTypeT))
53 | args = parser.parse_args()
54 |
55 | # General info about the dataset
56 | metadata = Metadata(
57 | description="MNIST dataset in the sedpack format",
58 | dataset_license="""
59 | Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, which is
60 | a derivative work from original NIST datasets. MNIST dataset is made
61 | available under the terms of the Creative Commons Attribution-Share Alike
62 | 3.0 license.
63 | """,
64 | custom_metadata={
65 | "list of authors": ["Yann LeCun", "Corinna Cortes"],
66 | },
67 | )
68 |
69 | # Types of attributes stored
70 | dataset_structure = DatasetStructure(
71 | saved_data_description=[
72 | Attribute(
73 | name="input",
74 | shape=(28, 28),
75 | dtype="float32",
76 | ),
77 | Attribute(
78 | name="digit",
79 | shape=(),
80 | dtype="uint8",
81 | ),
82 | ],
83 | compression=args.compression,
84 | shard_file_type=args.shard_file_type,
85 | )
86 |
87 | # Create a new dataset
88 | dataset = Dataset.create(
89 | path=args.dataset_directory, # All files are stored here
90 | metadata=metadata,
91 | dataset_structure=dataset_structure,
92 | )
93 |
94 | # Fill in examples
95 | (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
96 | x_train = x_train.astype("float32") / 256  # Scale pixels to [0, 1).
97 | x_test = x_test.astype("float32") / 256  # Dividing by 256 keeps values < 1.
98 |
99 | # DatasetFiller makes sure that all shard files are written properly
100 | # when exiting the context.
101 | with dataset.filler() as dataset_filler:
102 | # The original MNIST test data becomes our holdout split.
103 | for i in tqdm(range(len(x_test)), desc="holdout"):
104 | dataset_filler.write_example(
105 | values={
106 | "input": x_test[i],
107 | "digit": y_test[i],
108 | },
109 | split="holdout",
110 | )
111 |
112 | # Randomly assign 10% to validation and the rest to training.
113 | assert len(x_train) == len(y_train)
114 | train_indices: list[int] = list(range(len(x_train)))
115 | random.shuffle(train_indices)
116 | validation_split_position: int = int(len(x_train) * 0.1)
117 | for index_position, index in enumerate(
118 | tqdm(train_indices, desc="train and val")):
119 |
120 | # Assign to either train or test (aka validation).
121 | split: SplitT = "train"
122 | if index_position < validation_split_position:
123 | split = "test"
124 |
125 | # Write the example.
126 | dataset_filler.write_example(
127 | values={
128 | "input": x_train[index],
129 | "digit": y_train[index],
130 | },
131 | split=split,
132 | )
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
137 |
--------------------------------------------------------------------------------
/tests/io/test_continue_writing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from pathlib import Path
17 | import random
18 | from typing import Union
19 |
20 | import numpy as np
21 |
22 | import sedpack
23 | from sedpack.io import Dataset
24 | from sedpack.io import Metadata
25 |
26 |
27 | def get_dataset(tmpdir: Union[str, Path]) -> Dataset:
28 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
29 |
30 | # Create a dataset
31 |
32 | dataset_metadata = Metadata(description="Test of the lib")
33 |
34 | example_attributes = [
35 | sedpack.io.metadata.Attribute(
36 | name="attribute_name",
37 | dtype="float32",
38 | shape=(138,),
39 | ),
40 | ]
41 |
42 | dataset_structure = sedpack.io.metadata.DatasetStructure(
43 | saved_data_description=example_attributes,)
44 |
45 | dataset = Dataset.create(
46 | path=tiny_experiment_path,
47 | metadata=dataset_metadata,
48 | dataset_structure=dataset_structure,
49 | )
50 |
51 | return dataset
52 |
53 |
54 | def fill(dataset, split, data):
55 | # Fill data in the dataset
56 | with dataset.filler() as filler:
57 | for attribute_value in data:
58 | filler.write_example(
59 | values={"attribute_name": attribute_value},
60 | split=split,
61 | )
62 |
63 | # Check the data is correct
64 | # Reopen the dataset
65 | dataset = Dataset(dataset.path)
66 | dataset.check()
67 |
68 |
69 | def check_presence(dataset, split, data):
70 | for i, example in enumerate(
71 | dataset.as_numpy_iterator(
72 | split=split,
73 | shuffle=0,
74 | repeat=False,
75 | )):
76 | assert np.allclose(example["attribute_name"], data[i:i + 1])
77 |
78 | # We tested everything
79 | assert i + 1 == data.shape[0], "Not all examples have been iterated"
80 |
81 |
82 | def test_continue_writing_another_split(tmpdir: Union[str, Path]) -> None:
83 | """Check that we can write more examples into empty / single split. This
84 | would uncover the bug addressed by merging updates info.
85 | """
86 | data_train = np.random.random((1024, 138)).astype(dtype=np.float32)
87 | filled_train: int = 0
88 | data_test = np.random.random((1024, 138)).astype(dtype=np.float32)
89 | filled_test: int = 0
90 |
91 | dataset = get_dataset(tmpdir)
92 |
93 | fill_now: int = 20
94 |
95 | for _ in range(4):
96 | fill_now = random.randint(10, 50)
97 | fill(
98 | dataset=dataset,
99 | split="train",
100 | data=data_train[filled_train:filled_train + fill_now],
101 | )
102 | filled_train += fill_now
103 |
104 | fill_now = random.randint(10, 50)
105 | fill(
106 | dataset=dataset,
107 | split="test",
108 | data=data_test[filled_test:filled_test + fill_now],
109 | )
110 | filled_test += fill_now
111 |
112 | # Both splits are present after writing (not just directly after
113 | # writing into one).
114 | check_presence(dataset, "train", data_train[:filled_train])
115 | check_presence(Dataset(dataset.path), "train",
116 | data_train[:filled_train])
117 | assert dataset._dataset_info.splits[
118 | "train"].number_of_examples == filled_train
119 | check_presence(dataset, "test", data_test[:filled_test])
120 | check_presence(Dataset(dataset.path), "test", data_test[:filled_test])
121 | assert dataset._dataset_info.splits[
122 | "test"].number_of_examples == filled_test
123 |
124 |
125 | def test_local_root_path(tmpdir: Union[str, Path]) -> None:
126 | """Check that relative path checks work even when dataset root path is
127 | local. For this we need to write multiple times in the dataset.
128 | """
129 | # Change the working directory to be in the /tmp/pytest-of-user/
130 | os.chdir(tmpdir)
131 |
132 | data_train = np.random.random((1024, 138)).astype(dtype=np.float32)
133 | filled_train: int = 0
134 | data_test = np.random.random((1024, 138)).astype(dtype=np.float32)
135 | filled_test: int = 0
136 |
137 | dataset = get_dataset("my_dataset")
138 |
139 | fill_now: int = 20
140 |
141 | for _ in range(4):
142 | fill_now = random.randint(10, 50)
143 | fill(
144 | dataset=dataset,
145 | split="train",
146 | data=data_train[filled_train:filled_train + fill_now],
147 | )
148 | filled_train += fill_now
149 |
150 | fill_now = random.randint(10, 50)
151 | fill(
152 | dataset=dataset,
153 | split="test",
154 | data=data_test[filled_test:filled_test + fill_now],
155 | )
156 | filled_test += fill_now
157 |
--------------------------------------------------------------------------------
/src/sedpack/io/merge_shard_infos.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Merge new shard lists into the dataset.
15 | """
16 |
17 | from collections import defaultdict
18 | from pathlib import Path
19 |
20 | from sedpack.io.shard_file_metadata import ShardsList, ShardListInfo
21 | from sedpack.io.types import HashChecksumT
22 |
23 |
24 | def merge_shard_infos(updates: list[ShardListInfo], dataset_root: Path,
25 | common: int, hashes: tuple[HashChecksumT,
26 | ...]) -> ShardListInfo:
27 | """Merge a list of new `ShardListInfo`s into the dataset.
28 |
29 | Args:
30 |
31 | updates (list[ShardListInfo]): New shard lists information to merge. All
32 | of these must belong to the same split.
33 |
34 | dataset_root (Path): Where the dataset is saved.
35 |
36 | common (int): A positive integer indicating how many directory levels are
37 | shared between all `ShardListInfo`s (relative to the `dataset_root`).
38 | When calling this function all `updates` have to be in the same split,
39 | thus one should set `common=1`. It is not guaranteed to update
40 | `shards_list.json` files all the way to the split when `common>1`.
41 |
42 | hashes (tuple[HashChecksumT, ...]): A tuple of hash checksum algorithms.
43 | """
44 | assert updates, "Nothing to update."
45 |
46 | # Check that all common prefixes are the same.
47 | for update in updates:
48 | if update.shard_list_info_file.file_path.parts[:common] != updates[
49 | 0].shard_list_info_file.file_path.parts[:common]:
50 | raise ValueError(
51 | f"Not all relative paths are the same "
52 | f"{update.shard_list_info_file.file_path.parts[:common]} vs "
53 | f"{updates[0].shard_list_info_file.file_path.parts[:common]}")
54 |
55 | # The current level ShardsList (if it is in the updates).
56 | root_shard_list: ShardsList = ShardsList.load_or_create(
57 | dataset_root_path=dataset_root,
58 | relative_path_self=Path().joinpath(
59 | *updates[0].shard_list_info_file.file_path.parts[:common]) /
60 | "shards_list.json",
61 | )
62 |
63 | # Divide the updates to current level shard lists and the deeper ones.
64 | current_level: list[ShardListInfo] = [
65 | update for update in updates
66 | if len(update.shard_list_info_file.file_path.parts) == common + 1
67 | ]
68 | deeper_updates: list[ShardListInfo] = [
69 | update for update in updates
70 | if len(update.shard_list_info_file.file_path.parts) > common + 1
71 | ]
72 | # Check correctness of this implementation (O(1) check just to make sure we
73 | # do not forget anything).
74 | assert len(current_level) + len(deeper_updates) == len(updates)
75 | # Since the ShardsList is saved in a file named shards_list.json there can
76 | # be at most one update in this depth. We can ignore it since it has been
77 | # loaded into root_shard_list.
78 | assert len(current_level) <= 1
79 |
80 | # Move children of root_shard_list into deeper_updates to let recursion
81 | # merge everything. The updates are listed after the already present.
82 | for child in root_shard_list.children_shard_lists:
83 | root_shard_list.number_of_examples -= child.number_of_examples
84 | deeper_updates = root_shard_list.children_shard_lists + deeper_updates
85 | root_shard_list.children_shard_lists = []
86 |
87 | # Recursively update children with one longer common prefix.
88 | # Sort the infos by common directory.
89 | recursively_update: defaultdict[str,
90 | list[ShardListInfo]] = defaultdict(list)
91 | for update in deeper_updates:
92 | # The path is at least `split / $DIRECTORY / shards_list.json` or
93 | # longer.
94 | current_path: Path = update.shard_list_info_file.file_path
95 | directory = str(current_path.parts[common])
96 | recursively_update[directory].append(update)
97 |
98 | # Recursively update.
99 | merged: dict[str, ShardListInfo] = { # Merge recursively.
100 | directory:
101 | merge_shard_infos(updates=recursive_updates,
102 | dataset_root=dataset_root,
103 | common=common + 1,
104 | hashes=hashes)
105 | for directory, recursive_updates in recursively_update.items()
106 | }
107 |
108 | # Merge the recursive into root_shard_list.
109 | for child in merged.values():
110 | root_shard_list.number_of_examples += child.number_of_examples
111 | root_shard_list.children_shard_lists.append(child)
112 |
113 | # Write the root_shard_list and return its ShardListInfo.
114 | return root_shard_list.write_config(
115 | dataset_root_path=dataset_root,
116 | hashes=hashes,
117 | )
118 |
--------------------------------------------------------------------------------
/tests/io/test_end2end_wrong_type.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | import pytest
17 | from typing import Union
18 |
19 | import numpy as np
20 | import numpy.typing as npt
21 |
22 | import sedpack
23 | from sedpack.io import Dataset
24 | from sedpack.io import Metadata
25 | from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT
26 |
27 |
28 | def end2end(tmpdir: Union[str, Path], input_dtype: npt.DTypeLike,
29 | saved_dtype: npt.DTypeLike, method: str,
30 | shard_file_type: ShardFileTypeT, compression: CompressionT) -> None:
31 | array_of_values = np.random.random((10, 138)) * 256
32 | array_of_values = np.array(array_of_values, dtype=input_dtype)
33 | #print(f">>> {array_of_values.dtype = }")
34 | #print(f">>> {array_of_values.shape = }")
35 | #print(f">>> {array_of_values[0, :10] = }")
36 | #print(f">>> {array_of_values.tobytes() = }")
37 |
38 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
39 |
40 | # Create a dataset
41 |
42 | dataset_metadata = Metadata(description="Test of the lib")
43 |
44 | example_attributes = [
45 | sedpack.io.metadata.Attribute(
46 | name="filling_before",
47 | dtype=str(array_of_values.dtype),
48 | shape=array_of_values.shape,
49 | ),
50 | sedpack.io.metadata.Attribute(
51 | name="attribute_name",
52 | dtype=str(saved_dtype),
53 | shape=array_of_values[0].shape,
54 | ),
55 | sedpack.io.metadata.Attribute(
56 | name="filling_after",
57 | dtype=str(array_of_values.dtype),
58 | shape=array_of_values.shape,
59 | ),
60 | ]
61 |
62 | dataset_structure = sedpack.io.metadata.DatasetStructure(
63 | saved_data_description=example_attributes,
64 | compression=compression,
65 | examples_per_shard=256,
66 | shard_file_type=shard_file_type,
67 | )
68 |
69 | dataset = Dataset.create(
70 | path=tiny_experiment_path,
71 | metadata=dataset_metadata,
72 | dataset_structure=dataset_structure,
73 | )
74 |
75 | # Fill data in the dataset
76 |
77 | with dataset.filler() as filler:
78 | for attribute_value in array_of_values:
79 | filler.write_example(
80 | values={
81 | "filling_before": array_of_values,
82 | "attribute_name": attribute_value,
83 | "filling_after": array_of_values,
84 | },
85 | split=TRAIN_SPLIT,
86 | )
87 |
88 | # Check the data is correct
89 | # Reopen the dataset
90 | dataset = Dataset(tiny_experiment_path)
91 | dataset.check()
92 |
93 | match method:
94 | case "as_tfdataset":
95 | for i, example in enumerate(
96 | dataset.as_tfdataset(
97 | split=TRAIN_SPLIT,
98 | shuffle=0,
99 | repeat=False,
100 | batch_size=1,
101 | )):
102 | assert np.allclose(example["attribute_name"],
103 | array_of_values[i:i + 1])
104 | case "as_numpy_iterator":
105 | for i, example in enumerate(
106 | dataset.as_numpy_iterator(
107 | split=TRAIN_SPLIT,
108 | shuffle=0,
109 | repeat=False,
110 | )):
111 | assert np.allclose(example["attribute_name"],
112 | array_of_values[i])
113 | case "as_numpy_iterator_concurrent":
114 | for i, example in enumerate(
115 | dataset.as_numpy_iterator_concurrent(
116 | split=TRAIN_SPLIT,
117 | shuffle=0,
118 | repeat=False,
119 | )):
120 | assert np.allclose(example["attribute_name"],
121 | array_of_values[i])
122 |
123 | # We tested everything
124 | assert i + 1 == array_of_values.shape[
125 | 0], "Not all examples have been iterated"
126 |
127 |
128 | def test_end2end_wrong_value_type(tmpdir: Union[str, Path]) -> None:
129 | end2end(tmpdir=tmpdir,
130 | input_dtype="uint8",
131 | saved_dtype="int32",
132 | method="as_numpy_iterator",
133 | shard_file_type="fb",
134 | compression="LZ4")
135 |
136 |
137 | def test_end2end_wrong_value_type_no_cast(tmpdir: Union[str, Path]) -> None:
138 | with pytest.raises(ValueError):
139 | end2end(tmpdir=tmpdir,
140 | input_dtype="float32",
141 | saved_dtype="uint8",
142 | method="as_numpy_iterator",
143 | shard_file_type="fb",
144 | compression="LZ4")
145 |
--------------------------------------------------------------------------------
/tests/io/test_as_tfdataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test all shard file types with as_tfdataset.
15 | """
16 |
17 | import itertools
18 | from pathlib import Path
19 | from typing import Callable, Union
20 |
21 | import pytest
22 | import numpy as np
23 | import numpy.typing as npt
24 |
25 | import sedpack
26 | from sedpack.io import Dataset, Metadata
27 | from sedpack.io.shard.iterate_shard_base import T
28 | from sedpack.io.shard.shard_writer_flatbuffer import ShardWriterFlatBuffer
29 | from sedpack.io.shard.shard_writer_np import ShardWriterNP
30 | from sedpack.io.shard.shard_writer_tfrec import ShardWriterTFRec
31 | from sedpack.io.types import (
32 | ExampleT,
33 | TRAIN_SPLIT,
34 | CompressionT,
35 | ShardFileTypeT,
36 | )
37 | from sedpack.io.utils import is_module_present
38 |
39 |
40 | def end2end(
41 | tmpdir: Union[str, Path],
42 | dtype: npt.DTypeLike,
43 | shard_file_type: ShardFileTypeT,
44 | compression: CompressionT,
45 | process_record: Callable[[ExampleT], T] | None,
46 | ) -> None:
47 | array_of_values = np.random.random((1024, 138))
48 | array_of_values = array_of_values.astype(dtype)
49 |
50 | tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"
51 |
52 | # Create a dataset
53 |
54 | dataset_metadata = Metadata(description="Test of the lib")
55 |
56 | example_attributes = [
57 | sedpack.io.metadata.Attribute(
58 | name="attribute_name",
59 | dtype=str(dtype),
60 | shape=array_of_values[0].shape,
61 | ),
62 | ]
63 |
64 | dataset_structure = sedpack.io.metadata.DatasetStructure(
65 | saved_data_description=example_attributes,
66 | compression=compression,
67 | examples_per_shard=256,
68 | shard_file_type=shard_file_type,
69 | )
70 |
71 | dataset = Dataset.create(
72 | path=tiny_experiment_path,
73 | metadata=dataset_metadata,
74 | dataset_structure=dataset_structure,
75 | )
76 |
77 | # Fill data in the dataset
78 |
79 | with dataset.filler() as filler:
80 | for attribute_value in array_of_values:
81 | filler.write_example(
82 | values={"attribute_name": attribute_value},
83 | split=TRAIN_SPLIT,
84 | )
85 |
86 | # Check the data is correct
87 | # Reopen the dataset
88 | dataset = Dataset(tiny_experiment_path)
89 | dataset.check()
90 |
91 | for i, example in enumerate(
92 | dataset.as_tfdataset(
93 | split=TRAIN_SPLIT,
94 | shuffle=0,
95 | repeat=False,
96 | batch_size=1,
97 | process_record=process_record,
98 | )):
99 | if process_record:
100 | assert np.allclose(
101 | example["attribute_name"],
102 | process_record({"attribute_name": array_of_values[i:i + 1]
103 | })["attribute_name"],
104 | )
105 | else:
106 | assert np.allclose(example["attribute_name"],
107 | array_of_values[i:i + 1])
108 |
109 | # We tested everything
110 | assert i + 1 == array_of_values.shape[
111 | 0], "Not all examples have been iterated"
112 |
113 |
114 | @pytest.mark.skipif(
115 | not is_module_present("tensorflow"),
116 | reason="TensorFlow is optional, skip test if not present.",
117 | )
118 | @pytest.mark.parametrize(
119 | "shard_file_type,compression,dtype,process_record",
120 | itertools.chain(
121 | itertools.product(
122 | ["tfrec"],
123 | ShardWriterTFRec.supported_compressions(),
124 | ["float16", "float32"],
125 | [
126 | None,
127 | lambda d: {
128 | k: v + 1 for k, v in d.items()
129 | },
130 | ],
131 | ),
132 | itertools.product(
133 | ["npz"],
134 | ShardWriterNP.supported_compressions(),
135 | ["float16", "float32"],
136 | [
137 | None,
138 | lambda d: {
139 | k: v + 1 for k, v in d.items()
140 | },
141 | ],
142 | ),
143 | itertools.product(
144 | ["fb"],
145 | ShardWriterFlatBuffer.supported_compressions(),
146 | ["float16", "float32"],
147 | [
148 | None,
149 | lambda d: {
150 | k: v + 1 for k, v in d.items()
151 | },
152 | ],
153 | ),
154 | ),
155 | )
156 | def test_end2end_as_tfdataset(
157 | shard_file_type: str,
158 | compression: str,
159 | dtype: str,
160 | process_record: Callable[[ExampleT], T] | None,
161 | tmp_path: Union[str, Path],
162 | ) -> None:
163 | end2end(
164 | tmpdir=tmp_path,
165 | dtype=dtype,
166 | shard_file_type=shard_file_type,
167 | compression=compression,
168 | process_record=process_record,
169 | )
170 |
--------------------------------------------------------------------------------
/docs/tutorials/quick_start/mnist_read_keras.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Read MNIST data and feed it to a neural network. For a tutorial with
15 | explanations see: https://google.github.io/sedpack/tutorials/mnist
16 |
17 | Example use:
18 | python mnist_save.py -d "~/Datasets/my_new_dataset/"
19 | python mnist_read_keras.py -d "~/Datasets/my_new_dataset/"
20 | """
21 | import argparse
22 | from typing import Any, Tuple
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from tensorflow import keras
27 | from tensorflow.keras import layers
28 |
29 | from sedpack.io import Dataset
30 | from sedpack.io.types import ExampleT, TFModelT
31 |
32 |
33 | def get_model() -> TFModelT:
34 | """Return a CNN model.
35 | """
36 | input_shape = (28, 28)
37 | num_classes = 10
38 |
39 | input_data = keras.Input(shape=input_shape, name="input")
40 |
41 | x = input_data
42 | x = layers.Reshape((*input_shape, 1))(x)
43 | x = layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(x)
44 | x = layers.MaxPooling2D(pool_size=(2, 2))(x)
45 | x = layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x)
46 | x = layers.MaxPooling2D(pool_size=(2, 2))(x)
47 | x = layers.Flatten()(x)
48 | x = layers.Dropout(0.5)(x)
49 | x = layers.Dense(num_classes, activation="softmax", name="digit")(x)
50 |
51 | model: TFModelT = keras.Model(inputs=input_data, outputs=x)
52 |
53 | model.summary()
54 | model.compile(loss="categorical_crossentropy",
55 | optimizer="adam",
56 | metrics=["accuracy"])
57 | return model
58 |
59 |
60 | def main() -> None:
61 | """Train a neural network on the MNIST dataset saved in the sedpack
62 | format.
63 | """
64 | parser = argparse.ArgumentParser(
65 | description=
66 | "Read MNIST in dataset-lib format and train a small neural network.")
67 | parser.add_argument("--dataset_directory",
68 | "-d",
69 | help="Where to load the dataset",
70 | required=True)
71 | parser.add_argument("--ascii_evaluations",
72 | "-e",
73 | help="How many images to print and evaluate",
74 | type=int,
75 | default=10)
76 | args = parser.parse_args()
77 |
78 | # Load train and test and train
79 | model = get_model()
80 |
81 | dataset = Dataset(args.dataset_directory) # Load the dataset
82 |
83 | # ExampleT: TypeAlias of dict[str, sedpack.io.types.AttributeValueT]
84 | def process_record(rec: ExampleT) -> Tuple[Any, Any]:
85 | output = rec["digit"]
86 | output = tf.one_hot(output, 10)
87 | return rec["input"], output
88 |
89 | # Load train and validation splits of the dataset
90 | batch_size = 128
91 | train_data = dataset.as_tfdataset(
92 | "train",
93 | batch_size=batch_size,
94 | process_record=process_record,
95 | )
96 | validation_data = dataset.as_tfdataset(
97 | "test", # validation split
98 | batch_size=batch_size,
99 | process_record=process_record,
100 | )
101 |
102 | steps_per_epoch = 100
103 | epochs = 10
104 | _ = model.fit(
105 | train_data,
106 | steps_per_epoch=steps_per_epoch,
107 | epochs=epochs,
108 | validation_data=validation_data,
109 | validation_steps=steps_per_epoch // 10,
110 | )
111 |
112 | # Evaluate the model on holdout.
113 | holdout_data = dataset.as_tfdataset(
114 | "holdout",
115 | batch_size=batch_size,
116 | process_record=process_record,
117 | repeat=False, # Single iteration over the dataset.
118 | )
119 | score = model.evaluate(
120 | holdout_data,
121 | verbose=0,
122 | )
123 | print(f"Test loss: {score[0]}")
124 | print(f"Test accuracy: {100 * score[1]:.2f}%")
125 |
126 | evaluated: int = 0
127 | ascii_shades = " .-/X0#"
128 | for example in dataset.as_numpy_iterator(split="holdout",
129 | process_record=process_record):
130 | # Stop once the requested number of evaluations is reached.
131 | if evaluated >= args.ascii_evaluations:
132 | break
133 | evaluated += 1
134 |
135 | # Pass just the input (the handwritten digit image) to the model and get
136 | # the predicted class as the class with highest probability.
137 | image: list[list[float]] = example[0] # type: ignore[assignment,index]
138 | # Note that the model still expects a batch, here we pass a batch of one
139 | # image.
140 | predicted_class: int = np.argmax(model(np.expand_dims(
141 | image, axis=0))) # type: ignore[assignment]
142 | correct_class: int = np.argmax(
143 | example[1]) # type: ignore[assignment,index]
144 | print("")
145 | print(f"Predicted: {predicted_class} (should be {correct_class}) for")
146 | # Turn into ASCII art
147 | for row in image:
148 | print("".join(
149 | ascii_shades[int(pixel * len(ascii_shades))] for pixel in row))
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
--------------------------------------------------------------------------------