├── .circleci
└── config.yml
├── .clang-format
├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── ACKNOWLEDGMENTS.md
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks
├── comparative
│ ├── README.md
│ ├── caltech101
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── mlx_data.py
│ │ ├── pytorch.py
│ │ ├── run_caltech.sh
│ │ ├── tfds.py
│ │ └── utils.py
│ ├── librispeech
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── mlx_data.py
│ │ ├── pytorch.py
│ │ ├── run_librispeech.sh
│ │ ├── tfds.py
│ │ └── utils.py
│ └── wikitext
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── mlx_data.py
│ │ ├── pytorch.py
│ │ ├── run_wikitext.sh
│ │ ├── tfds.py
│ │ └── utils.py
└── utils.py
├── cmake
├── FindFFMPEG.cmake
├── FindFLAC.cmake
├── FindJPEGTURBO.cmake
├── FindOgg.cmake
├── FindSampleRate.cmake
├── FindSndFile.cmake
├── FindVorbis.cmake
├── bxzstr-v1.2.3.patch
└── pybind11-v2.11.1.patch
├── docs
├── .gitignore
├── Makefile
├── README.md
├── requirements.txt
└── src
│ ├── _static
│ ├── mlx_logo.png
│ └── mlx_logo_dark.png
│ ├── _templates
│ └── data_core_modules.rst
│ ├── buffers_streams_samples.rst
│ ├── conf.py
│ ├── hf_datasets_streams.rst
│ ├── index.rst
│ ├── install.rst
│ ├── python
│ ├── buffer.rst
│ ├── common_datasets.rst
│ ├── dataset.rst
│ ├── features.rst
│ ├── miscellaneous.rst
│ ├── stream.rst
│ └── tokenizing.rst
│ └── quick_start.rst
├── mlx-data.pc.in
├── mlx
└── data
│ ├── Array.cpp
│ ├── Array.h
│ ├── Buffer.cpp
│ ├── Buffer.h
│ ├── Dataset.cpp
│ ├── Dataset.h
│ ├── Sample.cpp
│ ├── Sample.h
│ ├── Stream.cpp
│ ├── Stream.h
│ ├── buffer
│ ├── Batch.cpp
│ ├── Batch.h
│ ├── Buffer.cpp
│ ├── Buffer.h
│ ├── DynamicBatch.cpp
│ ├── DynamicBatch.h
│ ├── FilesFromTAR.cpp
│ ├── FilesFromTAR.h
│ ├── FromStream.cpp
│ ├── FromStream.h
│ ├── FromVector.cpp
│ ├── FromVector.h
│ ├── Partition.cpp
│ ├── Partition.h
│ ├── Perm.cpp
│ ├── Perm.h
│ ├── Shuffle.cpp
│ ├── Shuffle.h
│ ├── Transform.cpp
│ └── Transform.h
│ ├── core
│ ├── AWSFileFetcher.cpp
│ ├── AWSFileFetcher.h
│ ├── BPETokenizer.cpp
│ ├── BPETokenizer.h
│ ├── BatchShape.cpp
│ ├── BatchShape.h
│ ├── CSVReader.cpp
│ ├── CSVReader.h
│ ├── FileFetcher.cpp
│ ├── FileFetcher.h
│ ├── Graph.cpp
│ ├── Graph.h
│ ├── Levenshtein.cpp
│ ├── Levenshtein.h
│ ├── Numpy.cpp
│ ├── Numpy.h
│ ├── State.cpp
│ ├── State.h
│ ├── TARReader.cpp
│ ├── TARReader.h
│ ├── ThreadController.cpp
│ ├── ThreadController.h
│ ├── ThreadPool.cpp
│ ├── ThreadPool.h
│ ├── Tokenizer.cpp
│ ├── Tokenizer.h
│ ├── Trie.h
│ ├── Utils.cpp
│ ├── Utils.h
│ ├── audio
│ │ ├── Audio.cpp
│ │ ├── Audio.h
│ │ ├── AudioPrivate.h
│ │ ├── AudioSampleRate.cpp
│ │ └── AudioSndfile.cpp
│ ├── image
│ │ ├── Image.h
│ │ ├── ImageIO.cpp
│ │ ├── ImageJPEG.cpp
│ │ ├── ImagePrivate.h
│ │ ├── ImageSTBI.cpp
│ │ └── ImageTransform.cpp
│ ├── imemstream.h
│ └── video
│ │ ├── Video.cpp
│ │ ├── Video.h
│ │ ├── VideoFFMPEG.cpp
│ │ └── VideoPrivate.h
│ ├── op
│ ├── FilterByShape.cpp
│ ├── FilterByShape.h
│ ├── FilterKey.cpp
│ ├── FilterKey.h
│ ├── ImageTransform.cpp
│ ├── ImageTransform.h
│ ├── KeyTransform.cpp
│ ├── KeyTransform.h
│ ├── LoadAudio.cpp
│ ├── LoadAudio.h
│ ├── LoadFile.cpp
│ ├── LoadFile.h
│ ├── LoadImage.cpp
│ ├── LoadImage.h
│ ├── LoadNumpy.cpp
│ ├── LoadNumpy.h
│ ├── LoadVideo.cpp
│ ├── LoadVideo.h
│ ├── Op.cpp
│ ├── Op.h
│ ├── Pad.cpp
│ ├── Pad.h
│ ├── ReadFromTAR.cpp
│ ├── ReadFromTAR.h
│ ├── RemoveValue.cpp
│ ├── RemoveValue.h
│ ├── RenameKey.cpp
│ ├── RenameKey.h
│ ├── Replace.cpp
│ ├── Replace.h
│ ├── SampleTransform.cpp
│ ├── SampleTransform.h
│ ├── SaveImage.cpp
│ ├── SaveImage.h
│ ├── Shape.cpp
│ ├── Shape.h
│ ├── Shard.cpp
│ ├── Shard.h
│ ├── Slice.cpp
│ ├── Slice.h
│ ├── Squeeze.cpp
│ ├── Squeeze.h
│ ├── Tokenize.cpp
│ └── Tokenize.h
│ └── stream
│ ├── Batch.cpp
│ ├── Batch.h
│ ├── Buffered.cpp
│ ├── Buffered.h
│ ├── CSVReader.cpp
│ ├── CSVReader.h
│ ├── Compose.cpp
│ ├── Compose.h
│ ├── DynamicBatch.cpp
│ ├── DynamicBatch.h
│ ├── FromBuffer.cpp
│ ├── FromBuffer.h
│ ├── LineReader.cpp
│ ├── LineReader.h
│ ├── OrderedPrefetch.cpp
│ ├── OrderedPrefetch.h
│ ├── Partition.cpp
│ ├── Partition.h
│ ├── Prefetch.cpp
│ ├── Prefetch.h
│ ├── Repeat.cpp
│ ├── Repeat.h
│ ├── Shuffle.cpp
│ ├── Shuffle.h
│ ├── SlidingWindow.cpp
│ ├── SlidingWindow.h
│ ├── Stream.cpp
│ ├── Stream.h
│ ├── Transform.cpp
│ └── Transform.h
├── pyproject.toml
├── python
├── mlx
│ └── data
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── datasets
│ │ ├── __init__.py
│ │ ├── cifar.py
│ │ ├── common.py
│ │ ├── image_folder.py
│ │ ├── imagenet.py
│ │ ├── librispeech.py
│ │ ├── libritts_r.py
│ │ ├── mnist.py
│ │ ├── speechcommands.py
│ │ └── wikitext.py
│ │ ├── features
│ │ ├── __init__.py
│ │ └── audio.py
│ │ └── tokenizer_helpers.py
├── src
│ ├── CMakeLists.txt
│ ├── wrap.cpp
│ ├── wrap.h
│ ├── wrap_buffer.cpp
│ ├── wrap_core.cpp
│ ├── wrap_dataset.h
│ └── wrap_stream.cpp
└── tests
│ ├── test_bpe.py
│ ├── test_buffer.py
│ ├── test_general_ops.py
│ └── test_replace.py
├── setup.py
├── super
├── CMakeLists.txt
└── cmake
│ ├── aws-1.11.557.patch
│ ├── bzip2-1.0.8.patch
│ ├── flac-1.5.0.patch
│ └── xvidcore-1.3.7.patch
└── tests
└── CMakeLists.txt
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.jpeg filter=lfs diff=lfs merge=lfs -text
2 | *.jpg filter=lfs diff=lfs merge=lfs -text
3 | *.png filter=lfs diff=lfs merge=lfs -text
4 | *.mov filter=lfs diff=lfs merge=lfs -text
5 |
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # build artifacts
2 | build/
3 | xcode/
4 | dist/
5 | __pycache__/
6 | *.egg-info/
7 | *.so
8 | *.c
9 | *.swp
10 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/mirrors-clang-format
3 | rev: v19.1.7
4 | hooks:
5 | - id: clang-format
6 | - repo: https://github.com/psf/black-pre-commit-mirror
7 | rev: 25.1.0
8 | hooks:
9 | - id: black
10 | - repo: https://github.com/pycqa/isort
11 | rev: 6.0.0
12 | hooks:
13 | - id: isort
14 | args:
15 | - --profile=black
16 | - repo: https://github.com/cheshirekow/cmake-format-precommit
17 | rev: v0.6.13
18 | hooks:
19 | - id: cmake-format
20 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MLX data
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Requests
7 |
8 | 1. Fork and submit pull requests to the repo.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before
11 | and after the change. Examples of benchmarks can be found in `benchmarks/`.
12 | 4. If you've changed APIs, update the documentation.
13 | 5. Every PR should have passing tests and at least one review.
14 | 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
15 | This should install hooks for running `black`, `isort`, `clang-format` and
16 | `cmake-format` to ensure consistent style for C++, Python and CMake code.
17 |
18 | You can also run the formatters manually as follows:
19 |
20 | ```
21 | clang-format -i file.cpp
22 | ```
23 |
24 | ```
25 | black file.py
26 | ```
27 |
28 | or run `pre-commit run --all-files` to check all files in the repo.
29 |
30 | ## Issues
31 |
32 | We use GitHub issues to track public bugs. Please ensure your description is
33 | clear and has sufficient instructions to be able to reproduce the issue.
34 |
35 | ## License
36 |
37 | By contributing to MLX data, you agree that your contributions will be licensed
38 | under the LICENSE file in the root directory of this source tree.
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright © 2023 Apple Inc. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/benchmarks/comparative/README.md:
--------------------------------------------------------------------------------
1 | Comparative Benchmarks
2 | ======================
3 |
4 | In this folder we compare `mlx.data` with PyTorch DataLoaders and `tf.data`. We
5 | try to keep the comparison as fair as possible but like any benchmark it can
6 | never replace real-world performance measurements on your particular use-case.
7 |
8 | The goal is to show that `mlx.data` is concise and fast.
9 |
10 | Running the benchmarks
11 | ----------------------
12 |
13 | Each folder has separate instructions on how to download the data and run the benchmark.
14 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/README.md:
--------------------------------------------------------------------------------
1 | Caltech 101 Benchmark
2 | =====================
3 |
4 | This benchmark creates a simple image classification data loading pipeline from
5 | a set of directories containing images. In practice, this mirrors what
6 | `torchvision`'s `ImageFolder` dataset implements.
7 |
8 | For PyTorch we use the aforementioned `ImageFolder` dataset while for `tf.data`
9 | we build the pipeline according to [their image classification
10 | tutorial](https://www.tensorflow.org/tutorials/load_data/images).
11 |
12 | We apply a minimal set of transforms, namely we resize the original image to
13 | 256 pixels across the smallest dimension, center-crop to 224 by 224 pixels and
14 | transform the pixel values to floating point in [0, 1].
15 |
16 | Getting the data
17 | ----------------
18 |
19 | You can download the data manually from
20 | [https://data.caltech.edu/records/mzrjq-6wc02](https://data.caltech.edu/records/mzrjq-6wc02),
21 | or use the provided bash script to download the data, extract them and run the
22 | benchmarks one after the other as follows:
23 |
24 | bash run_caltech.sh
25 |
26 |
27 | Running the benchmarks
28 | ----------------------
29 |
30 | If you already have the data, you can run each benchmark by simply pointing it
31 | to the directory that holds the Caltech101 dataset. Just make sure that the archive
32 | containing the images is untarred.
33 |
34 | python mlx_data.py /path/to/Caltech101
35 |
36 |
37 | Dependencies
38 | ------------
39 |
40 | To automatically download the data you need `wget`, `unzip` and `tar`. To run the
41 | benchmarks you need `tensorflow` and `PyTorch` with `torchvision` installed.
42 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/mlx_data.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import argparse
4 | from pathlib import Path
5 |
6 | from utils import Benchmark
7 |
8 | import mlx.data as dx
9 |
10 |
def files_and_classes(root: Path):
    """Build the list of image samples found under ``root``.

    Every ``*.jpg`` below ``root`` (searched recursively) is kept unless its
    path contains ``"BACKGROUND"``. The class of a file is the name of its
    parent directory, and labels are the positions of the class names in
    sorted order.

    Returns:
        A list of dicts, each with an ``image`` entry (the ASCII-encoded
        path as bytes) and an integer ``label`` entry.
    """
    image_paths = []
    for candidate in root.glob("**/*.jpg"):
        as_text = str(candidate)
        if "BACKGROUND" in as_text:
            continue
        image_paths.append(as_text)

    # Class name -> integer label, numbered in sorted order of the names.
    class_names = sorted({p.split("/")[-2] for p in image_paths})
    label_of = {name: index for index, name in enumerate(class_names)}

    samples = []
    for p in image_paths:
        samples.append({"image": p.encode("ascii"), "label": label_of[p.split("/")[-2]]})
    return samples
21 |
22 |
def iterate(args, workers):
    """Run the mlx.data image pipeline once and count what it yields.

    The file list is shuffled, images are loaded, resized so the smallest side
    is 256, center-cropped to 224x224, batched, and converted to float32
    divided by 255. With ``args.ordered_prefetch`` the chain is applied to the
    buffer and finished with ``ordered_prefetch``; otherwise the buffer is
    first turned into a stream and finished with ``prefetch``.

    Args:
        args: parsed command-line arguments (``data_dir``, ``batch_size``,
            ``ordered_prefetch``).
        workers: value passed for both arguments of the prefetch call.

    Returns:
        The number of items (batches) produced by the pipeline.
    """
    data_root = Path(args.data_dir)

    source = dx.buffer_from_vector(files_and_classes(data_root)).shuffle()
    if not args.ordered_prefetch:
        source = source.to_stream()

    # The transform chain is identical for both the buffer and stream variants.
    pipeline = (
        source.load_image("image")
        .image_resize_smallest_side("image", 256)
        .image_center_crop("image", 224, 224)
        .batch(args.batch_size)
        .key_transform("image", lambda x: x.astype("float32") / 255)
    )
    if args.ordered_prefetch:
        pipeline = pipeline.ordered_prefetch(workers, workers)
    else:
        pipeline = pipeline.prefetch(workers, workers)

    n_batches = 0
    for _ in pipeline:
        n_batches += 1
    return n_batches
54 |
55 |
if __name__ == "__main__":
    # Command line: dataset directory plus optional batch size / prefetch mode.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--ordered_prefetch", action="store_true")
    args = parser.parse_args()

    benchmark = Benchmark("MLX Caltech 101")
    # Three timed repetitions for each worker configuration.
    for run_name, n_workers in (
        ("iterate_no_workers", 1),
        ("iterate_workers_8", 8),
        ("iterate_workers_16", 16),
    ):
        for _ in range(3):
            benchmark.log_run(run_name, iterate, args, n_workers)
    benchmark.report()
73 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/pytorch.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import argparse
4 |
5 | import torch
6 | from torchvision.datasets import ImageFolder
7 | from torchvision.transforms import v2 as transforms
8 | from utils import Benchmark
9 |
10 |
class Caltech101(ImageFolder):
    """``ImageFolder`` over a Caltech 101 directory with the benchmark's transforms."""

    def __init__(self, root: str):
        # Convert to an image tensor, resize to 256, center-crop to 224 and
        # cast to float32 with value scaling enabled.
        preprocessing = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToDtype(torch.float32, scale=True),
            ]
        )
        super().__init__(
            str(root),
            transform=preprocessing,
            is_valid_file=self.is_valid_file,
        )

    def is_valid_file(self, filepath: str):
        """Accept paths that end in ``jpg`` and do not contain ``BACKGROUND``."""
        if "BACKGROUND" in filepath:
            return False
        return filepath.endswith("jpg")
28 |
29 |
def iterate(args, workers):
    """Iterate a shuffled ``DataLoader`` over Caltech 101 once.

    Args:
        args: parsed command-line arguments (``data_dir``, ``batch_size``).
        workers: forwarded as ``num_workers`` to the ``DataLoader``.

    Returns:
        The number of batches the loader produced.
    """
    loader = torch.utils.data.DataLoader(
        Caltech101(args.data_dir),
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=workers,
    )
    return sum(1 for _ in loader)
40 |
41 |
if __name__ == "__main__":
    # Command line: dataset directory plus optional batch size.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    benchmark = Benchmark("PyTorch Caltech 101")
    # Three timed repetitions for each worker configuration.
    for run_name, n_workers in (
        ("iterate_no_workers", 0),
        ("iterate_workers_8", 8),
        ("iterate_workers_16", 16),
    ):
        for _ in range(3):
            benchmark.log_run(run_name, iterate, args, n_workers)
    benchmark.report()
58 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/run_caltech.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Download Caltech 101 into ./data (first run only), then run the PyTorch,
# tf.data and mlx.data benchmarks one after the other.

# Get the data
if [ ! -d data ]; then
  mkdir data || exit 1
  # Every step has to succeed. On failure the partial directory is removed so
  # that the next invocation downloads again instead of skipping this section
  # and benchmarking an incomplete dataset.
  if ! (
    cd data &&
    wget "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip?download=1" -O caltech-101.zip &&
    unzip caltech-101.zip && rm caltech-101.zip &&
    cd caltech-101 &&
    tar -zxf 101_ObjectCategories.tar.gz && rm 101_ObjectCategories.tar.gz
  ); then
    rm -rf data
    echo "Failed to download or extract the Caltech 101 dataset" >&2
    exit 1
  fi
fi

# Run the benchmarks
python pytorch.py data/caltech-101 2>/dev/null
echo "============="
python tfds.py data/caltech-101 2>/dev/null
echo "============="
python mlx_data.py data/caltech-101
21 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/tfds.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import argparse
4 | from functools import partial
5 | from pathlib import Path
6 |
7 | import tensorflow as tf
8 | from utils import Benchmark
9 |
10 |
def files_and_classes(root: Path):
    """List the image files under ``root`` together with their integer labels.

    Every ``*.jpg`` below ``root`` (searched recursively) is kept unless its
    path contains ``"BACKGROUND"``. A file's class is the name of its parent
    directory; labels are the positions of the class names in sorted order.

    Returns:
        A pair ``(files, labels)`` of equally long lists: path strings and
        the label of each path.
    """
    kept = [p for p in map(str, root.glob("**/*.jpg")) if "BACKGROUND" not in p]

    # Class name -> integer label, numbered in sorted order of the names.
    names = sorted({p.split("/")[-2] for p in kept})
    index_of = {name: position for position, name in enumerate(names)}

    labels = [index_of[p.split("/")[-2]] for p in kept]
    return kept, labels
20 |
21 |
def to(dtype, x):
    """Cast ``x`` to ``dtype`` with ``tf.cast``; lists are cast element by element."""
    if not isinstance(x, list):
        return tf.cast(x, dtype=dtype)
    return [to(dtype, item) for item in x]


# Convenience casts with the target dtype already bound.
to_float32 = partial(to, tf.float32)
to_int32 = partial(to, tf.int32)
30 |
31 |
def process_sample(sample):
    """Read and decode one JPEG, then resize, crop and rescale it.

    Args:
        sample: mapping with a ``"file"`` entry (path handed to
            ``tf.io.read_file``) and a ``"label"`` entry.

    Returns:
        A dict with ``image`` (resized so the shorter side is 256 pixels,
        then cropped or padded to 224x224, pixel values divided by 255) and
        the unchanged ``label``.
    """
    raw = tf.io.read_file(sample["file"])
    img = tf.io.decode_jpeg(raw, channels=3)

    height, width = to_float32([tf.shape(img)[0], tf.shape(img)[1]])
    # Scale so that the SHORTER side becomes 256 pixels. The ratio has to be
    # target / current; the inverse would enlarge big images and shrink small
    # ones instead of normalizing them.
    scale_factor = tf.constant(256, dtype=tf.float32) / tf.minimum(height, width)
    # Round before the integer cast so float error cannot turn 256 into 255.
    new_size = to_int32(
        [tf.round(scale_factor * height), tf.round(scale_factor * width)]
    )
    img = tf.image.resize(img, new_size)
    img = tf.image.resize_with_crop_or_pad(img, 224, 224)
    img = img / tf.constant(255, dtype=tf.float32)

    return dict(image=img, label=sample["label"])
45 |
46 |
def iterate(args, workers):
    """Run the ``tf.data`` image pipeline once and count its batches.

    Args:
        args: parsed command-line arguments (``data_dir``, ``batch_size``).
        workers: used as the prefetch size and as the private thread pool
            size of the dataset options.

    Returns:
        The number of batches the dataset produced.
    """
    files, labels = files_and_classes(Path(args.data_dir))

    paired = tf.data.Dataset.zip(
        {
            "file": tf.data.Dataset.from_tensor_slices(files),
            "label": tf.data.Dataset.from_tensor_slices(labels),
        }
    )
    pipeline = paired.shuffle(buffer_size=1000)
    pipeline = pipeline.map(process_sample)
    pipeline = pipeline.batch(args.batch_size)
    pipeline = pipeline.prefetch(workers)

    threading_options = tf.data.Options()
    threading_options.threading.private_threadpool_size = workers

    return sum(1 for _ in pipeline.with_options(threading_options))
69 |
70 |
if __name__ == "__main__":
    # Command line: dataset directory plus optional batch size.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    benchmark = Benchmark("TFDS Caltech 101")
    # Three timed repetitions for each worker configuration.
    for run_name, n_workers in (
        ("iterate_no_workers", 1),
        ("iterate_workers_8", 8),
        ("iterate_workers_16", 16),
    ):
        for _ in range(3):
            benchmark.log_run(run_name, iterate, args, n_workers)
    benchmark.report()
87 |
--------------------------------------------------------------------------------
/benchmarks/comparative/caltech101/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
class Benchmark:
    """Collect wall-clock timings of named runs and print a summary table.

    Runs are grouped by name. Each ``log_run`` call appends one timing to the
    group and stores the count returned by the benchmarked callable.
    """

    def __init__(self, name):
        self.name = name
        # run name -> elapsed seconds, one entry per logged run
        self.runtimes = defaultdict(list)
        # run name -> value returned by the most recent run of that name
        self.n_samples = defaultdict(int)

    def log_run(self, run_name, fn, *args, **kwargs):
        """Time ``fn(*args, **kwargs)`` and record the result under ``run_name``.

        ``fn`` is expected to return the number of items it processed; that
        value is what ``report`` divides by the average time.
        """
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or skewed durations.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        end = time.perf_counter()
        self.runtimes[run_name].append(end - start)
        self.n_samples[run_name] = n_samples

    def report(self):
        """Print one aligned row per run name: average time, throughput, runs."""
        print(f"Benchmark {self.name}")
        print()

        table = []
        for run_name, times in self.runtimes.items():
            n_runs = len(times)
            if n_runs == 0:
                # A name touched through the defaultdict but never timed.
                continue
            avg_time = sum(times) / n_runs
            # Guard the degenerate case of a zero measured duration.
            if avg_time > 0:
                avg_throughput = self.n_samples[run_name] / avg_time
            else:
                avg_throughput = float("inf")
            table.append(
                [
                    run_name,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        if not table:
            # Nothing was logged, so there is no table to size or print.
            return

        column_widths = [
            max(len(row[i]) for row in table) for i in range(len(table[0]))
        ]
        for row in table:
            for i, (v, w) in enumerate(zip(row, column_widths)):
                padding = " " * (w - len(v))
                # The first column is left-aligned, the others right-aligned.
                if i == 0:
                    print(v, padding, " " * 4, end="")
                else:
                    print(padding, v, " " * 4, end="")
            print()
49 |
--------------------------------------------------------------------------------
/benchmarks/comparative/librispeech/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 |
--------------------------------------------------------------------------------
/benchmarks/comparative/librispeech/README.md:
--------------------------------------------------------------------------------
1 | LibriSpeech Benchmark
2 | =====================
3 |
4 | This benchmark creates a simple ASR pipeline from the structure of the
5 | LibriSpeech dataset. Simply put, there is a set of text files containing lists
6 | of audio files and their transcripts. We transform the audio to Mel scaled
7 | spectrograms and the transcripts to tokens from a SentencePiece tokenizer.
8 |
9 | For PyTorch we use the torchaudio LIBRISPEECH dataset and the `MelSpectrogram`
10 | transform with soundfile to load the audio. For tensorflow we use
11 | `tensorflow_text` for the tokenizer and `tensorflow_io` to read the audio.
12 |
13 | A more realistic ASR pipeline would load the audio info first in order to
14 | subsequently group the files by length. This can be achieved efficiently in `mlx.data`
15 | using buffered streams.
16 |
17 | Getting the data
18 | ----------------
19 |
20 | You can download the data manually from
21 | [https://www.openslr.org/12](https://www.openslr.org/12) or use the provided
22 | bash script to download the data, the tokenizer and run the benchmarks. The
23 | tokenizer model we use is the Llama SPM model provided by Huggingface.
24 |
25 | bash run_librispeech.sh
26 |
27 |
28 | Running the benchmarks
29 | ----------------------
30 |
31 | If you already have the data, you can run each benchmark by simply pointing it
32 | to the directory that holds the LibriSpeech dataset. Just make sure the archive is
33 | extracted.
34 |
35 | OMP_NUM_THREADS=1 python mlx_data.py \
36 | --tokenizer_file /path/to/tokenizer.model \
37 | /path/to/librispeech/LibriSpeech/dev-clean
38 |
39 | You should run the PyTorch and `mlx.data` benchmarks with `OMP_NUM_THREADS=1`
40 | to avoid thread contention from computing the FFT and other operations on CPU.
41 |
42 | Dependencies
43 | ------------
44 |
45 | To automatically download the data you need `wget` and `tar`. To run the
46 | benchmarks you need the following dependencies:
47 |
48 | - `tensorflow`
49 | - `tensorflow_io`
50 | - `tensorflow_text`
51 | - `sentencepiece`
52 | - `torch`
53 | - `torchaudio`
54 | - `intel-numpy` for fast FFT
55 |
56 | Since the featurization in `mlx.data` happens in `numpy` we propose you install
57 | `intel-numpy` for a fast FFT implementation.
58 |
--------------------------------------------------------------------------------
/benchmarks/comparative/librispeech/mlx_data.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import argparse
4 | from pathlib import Path
5 |
6 | from mlx.data.features import mfsc
7 | from mlx.data.tokenizer_helpers import read_trie_from_spm
8 | from utils import Benchmark
9 |
10 | import mlx.data as dx
11 |
12 |
def to_audio_and_transcript(sample):
    """Turn one transcript-file line into an audio path and a lowercase transcript.

    The ``line`` entry is split at the first space into an utterance id and
    the transcript. The id's dash-separated leading components become
    directories, followed by ``<id>.flac`` as the file name.

    Returns:
        A dict with ``audio`` (relative path as bytes) and ``transcript``
        (lowercased bytes).
    """
    utterance_id, text = bytes(sample["line"]).split(b" ", 1)

    # Everything before the last dash-separated component names a directory.
    *directories, _ = utterance_id.split(b"-")
    audio_path = b"/".join([*directories, utterance_id + b".flac"])

    return {"audio": audio_path, "transcript": text.lower()}
26 |
27 |
def iterate(args, workers):
    """Run the mlx.data LibriSpeech pipeline once and count what it yields.

    Args:
        args: parsed command-line arguments (``data_dir``, ``batch_size``,
            ``tokenizer_file``).
        workers: value passed for both arguments of ``prefetch``.

    Returns:
        The number of items (batches) produced by the pipeline.
    """
    data_root = Path(args.data_dir)

    # One sample per transcript list file found under the data directory.
    transcript_lists = [
        {"file": str(txt).encode("ascii")} for txt in data_root.glob("**/*.txt")
    ]

    # Load the tokenizer
    trie, _ = read_trie_from_spm(args.tokenizer_file)

    # Read the list files line by line and derive audio path + transcript.
    lines = (
        dx.buffer_from_vector(transcript_lists)
        .shuffle()
        .to_stream()
        .line_reader_from_key("file", "line")
        .sample_transform(to_audio_and_transcript)
    )

    # Load the audio, compute features and record the first dimension's size.
    audio = (
        lines.load_audio("audio", prefix=args.data_dir)
        .squeeze("audio")
        .key_transform("audio", mfsc(128, 16000))
        .shape("audio", "audio_length", 0)
    )

    # Tokenize the transcript and pad one element on each side.
    # NOTE(review): both pads use the id of trie.search("") — confirm that the
    # empty string is the intended token here.
    tokens = (
        audio.tokenize("transcript", trie)
        .pad("transcript", 0, 1, 0, trie.search("").id)
        .pad("transcript", 0, 0, 1, trie.search("").id)
        .shape("transcript", "transcript_length", 0)
    )

    # Batch and prefetch
    batches = tokens.batch(args.batch_size).prefetch(workers, workers)

    n_batches = 0
    for _ in batches:
        n_batches += 1
    return n_batches
63 |
64 |
if __name__ == "__main__":
    # Command line: dataset directory plus optional batch size / tokenizer file.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--tokenizer_file", default="tokenizer.model")
    args = parser.parse_args()

    benchmark = Benchmark("MLX Librispeech")
    # Three timed repetitions for each worker configuration.
    for run_name, n_workers in (
        ("iterate_no_workers", 1),
        ("iterate_workers_8", 8),
        ("iterate_workers_16", 16),
    ):
        for _ in range(3):
            benchmark.log_run(run_name, iterate, args, n_workers)
    benchmark.report()
82 |
--------------------------------------------------------------------------------
/benchmarks/comparative/librispeech/run_librispeech.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Download LibriSpeech dev-clean and the tokenizer into ./data (first run
# only), then run the PyTorch, tf.data and mlx.data benchmarks in sequence.

# Get the data
if [ ! -d data ]; then
  mkdir data || exit 1
  # Every step has to succeed. On failure the partial directory is removed so
  # that the next invocation downloads again instead of skipping this section
  # and benchmarking an incomplete dataset.
  if ! (
    cd data &&
    wget "https://www.openslr.org/resources/12/dev-clean.tar.gz" -O dev-clean.tar.gz &&
    tar -zxf dev-clean.tar.gz && rm dev-clean.tar.gz &&
    wget "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model" -O tokenizer.model
  ); then
    rm -rf data
    echo "Failed to download or extract the LibriSpeech data or the tokenizer" >&2
    exit 1
  fi
fi

# Run the benchmarks
OMP_NUM_THREADS=1 python pytorch.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean 2>/dev/null
echo "============="
python tfds.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean 2>/dev/null
echo "============="
# The mlx.data run is pinned to one OpenMP thread as well, as the README
# recommends for both the PyTorch and mlx.data benchmarks.
OMP_NUM_THREADS=1 python mlx_data.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean
19 |
--------------------------------------------------------------------------------
/benchmarks/comparative/librispeech/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
class Benchmark:
    """Collect wall-clock timings of named runs and print a summary table.

    Runs are grouped by name. Each ``log_run`` call appends one timing to the
    group and stores the count returned by the benchmarked callable.
    """

    def __init__(self, name):
        self.name = name
        # run name -> elapsed seconds, one entry per logged run
        self.runtimes = defaultdict(list)
        # run name -> value returned by the most recent run of that name
        self.n_samples = defaultdict(int)

    def log_run(self, run_name, fn, *args, **kwargs):
        """Time ``fn(*args, **kwargs)`` and record the result under ``run_name``.

        ``fn`` is expected to return the number of items it processed; that
        value is what ``report`` divides by the average time.
        """
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or skewed durations.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        end = time.perf_counter()
        self.runtimes[run_name].append(end - start)
        self.n_samples[run_name] = n_samples

    def report(self):
        """Print one aligned row per run name: average time, throughput, runs."""
        print(f"Benchmark {self.name}")
        print()

        table = []
        for run_name, times in self.runtimes.items():
            n_runs = len(times)
            if n_runs == 0:
                # A name touched through the defaultdict but never timed.
                continue
            avg_time = sum(times) / n_runs
            # Guard the degenerate case of a zero measured duration.
            if avg_time > 0:
                avg_throughput = self.n_samples[run_name] / avg_time
            else:
                avg_throughput = float("inf")
            table.append(
                [
                    run_name,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        if not table:
            # Nothing was logged, so there is no table to size or print.
            return

        column_widths = [
            max(len(row[i]) for row in table) for i in range(len(table[0]))
        ]
        for row in table:
            for i, (v, w) in enumerate(zip(row, column_widths)):
                padding = " " * (w - len(v))
                # The first column is left-aligned, the others right-aligned.
                if i == 0:
                    print(v, padding, " " * 4, end="")
                else:
                    print(padding, v, " " * 4, end="")
            print()
49 |
--------------------------------------------------------------------------------
/benchmarks/comparative/wikitext/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 |
--------------------------------------------------------------------------------
/benchmarks/comparative/wikitext/README.md:
--------------------------------------------------------------------------------
1 | WikiText Benchmark
2 | =====================
3 |
4 | This benchmark reads the Wikitext103 dataset using a python generator,
5 | tokenizes the text using an SPM model, computes a sliding window of 1,025
6 | tokens and uses a shuffle buffer of 1,000 samples.
7 |
8 | Getting the data
9 | ----------------
10 |
11 | You can download the data manually from
12 | [blog.salesforceairesearch.com/.../](https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/)
13 | or use the provided bash script to download the data, the tokenizer and run the
14 | benchmarks. The tokenizer model we use is the Llama SPM model provided by
15 | Huggingface.
16 |
17 | bash run_wikitext.sh
18 |
19 |
20 | Running the benchmarks
21 | ----------------------
22 |
23 | If you already have the data, you can run each benchmark by simply pointing it
24 | to the directory that holds the wikitext dataset. Just make sure the archive is
25 | extracted.
26 |
27 | OMP_NUM_THREADS=1 python mlx_data.py \
28 | --tokenizer_file /path/to/tokenizer.model \
29 | /path/to/wikitext/wikitext-103-raw
30 |
31 | Dependencies
32 | ------------
33 |
34 | To automatically download the data you need `wget` and `unzip`. To run the
35 | TF benchmark you need the following dependencies:
36 |
37 | - `tensorflow`
38 | - `tensorflow_io`
39 | - `tensorflow_text`
40 | - `sentencepiece`
41 |
--------------------------------------------------------------------------------
/benchmarks/comparative/wikitext/run_wikitext.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Download WikiText-103 (raw) and the tokenizer into ./data (first run only),
# then run the PyTorch, tf.data and mlx.data benchmarks in sequence.

# Get the data
if [ ! -d data ]; then
  mkdir data || exit 1
  # Every step has to succeed. On failure the partial directory is removed so
  # that the next invocation downloads again instead of skipping this section
  # and benchmarking an incomplete dataset.
  if ! (
    cd data &&
    wget "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" -O wikitext-103-raw-v1.zip &&
    unzip wikitext-103-raw-v1.zip && rm wikitext-103-raw-v1.zip &&
    wget "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model" -O tokenizer.model
  ); then
    rm -rf data
    echo "Failed to download or extract the WikiText data or the tokenizer" >&2
    exit 1
  fi
fi

# Run the benchmarks
python pytorch.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw
echo "============="
python tfds.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw 2>/dev/null
echo "============="
python mlx_data.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw
19 |
--------------------------------------------------------------------------------
/benchmarks/comparative/wikitext/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
class Benchmark:
    """Collects wall-clock timings of named runs and prints a summary table.

    ``runtimes`` maps a run name to the list of measured durations in seconds
    (one entry per :meth:`log_run` call) and ``n_samples`` maps a run name to
    the value returned by the most recently timed function for that name.
    """

    def __init__(self, name):
        self.name = name
        self.runtimes = defaultdict(list)
        self.n_samples = defaultdict(int)

    def log_run(self, run_name, fn, *args, **kwargs):
        """Time ``fn(*args, **kwargs)`` and record the result under ``run_name``.

        The return value of ``fn`` is stored as the number of samples of the
        run; :meth:`report` divides it by the average time to get throughput.
        """
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump when the system clock is adjusted during the run.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        end = time.perf_counter()
        self.runtimes[run_name].append(end - start)
        self.n_samples[run_name] = n_samples

    def report(self):
        """Print one aligned row per run: name, average time, throughput, runs."""
        print(f"Benchmark {self.name}")
        print()

        table = []
        for k, times in self.runtimes.items():
            n_runs = len(times)
            if n_runs == 0:
                # The defaultdict entry was created without any logged run
                continue
            avg_time = sum(times) / n_runs
            # A run faster than the clock resolution has no finite throughput
            avg_throughput = (
                self.n_samples[k] / avg_time if avg_time > 0 else float("inf")
            )
            table.append(
                [
                    k,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        if not table:
            # Nothing was logged, there is no table to align or print
            return

        column_widths = [max(len(cell) for cell in column) for column in zip(*table)]
        for row in table:
            cells = []
            for i, (v, w) in enumerate(zip(row, column_widths)):
                padding = " " * (w - len(v))
                # The name column is left aligned, the numbers right aligned
                cells.append(f"{v} {padding}" if i == 0 else f"{padding} {v}")
            # Every cell is followed by a separator of five spaces
            print("".join(cell + " " * 5 for cell in cells))
49 |
--------------------------------------------------------------------------------
/benchmarks/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
class Benchmark:
    """Collects wall-clock timings of named runs and prints a summary table.

    ``runtimes`` maps a run name to the list of measured durations in seconds
    (one entry per :meth:`log_run` call), ``n_samples`` maps a run name to the
    value returned by the most recently timed function for that name and
    ``run_count`` counts all the runs logged so far.
    """

    def __init__(self, name):
        self.name = name
        self.runtimes = defaultdict(list)
        self.n_samples = defaultdict(int)
        self.run_count = 0

    def log_run(self, run_name, fn, *args, **kwargs):
        """Time ``fn(*args, **kwargs)`` and record the result under ``run_name``.

        The return value of ``fn`` is stored as the number of samples of the
        run; :meth:`report` divides it by the average time to get throughput.
        A progress line is printed before the run and a separator after it.
        """
        print(f"Starting run {self.run_count}.", flush=True)
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump when the system clock is adjusted during the run.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        end = time.perf_counter()
        self.runtimes[run_name].append(end - start)
        self.n_samples[run_name] = n_samples
        print("-------------")
        self.run_count += 1

    def report(self):
        """Print one aligned row per run: name, average time, throughput, runs."""
        print(f"Benchmark {self.name}")
        print()

        table = []
        for k, times in self.runtimes.items():
            n_runs = len(times)
            if n_runs == 0:
                # The defaultdict entry was created without any logged run
                continue
            avg_time = sum(times) / n_runs
            # A run faster than the clock resolution has no finite throughput
            avg_throughput = (
                self.n_samples[k] / avg_time if avg_time > 0 else float("inf")
            )
            table.append(
                [
                    k,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        if not table:
            # Nothing was logged, there is no table to align or print
            return

        column_widths = [max(len(cell) for cell in column) for column in zip(*table)]
        for row in table:
            cells = []
            for i, (v, w) in enumerate(zip(row, column_widths)):
                padding = " " * (w - len(v))
                # The name column is left aligned, the numbers right aligned
                cells.append(f"{v} {padding}" if i == 0 else f"{padding} {v}")
            # Every cell is followed by a separator of five spaces
            print("".join(cell + " " * 5 for cell in cells))
53 |
--------------------------------------------------------------------------------
/cmake/FindFLAC.cmake:
--------------------------------------------------------------------------------
# FindFLAC
# --------
#
# Find the native FLAC includes and libraries.
#
# Sets the following imported targets if FLAC is found: FLAC::FLAC
#
# Sets the following legacy CMake variables:
#   FLAC_INCLUDE_DIRS - where to find FLAC headers.
#   FLAC_LIBRARIES    - list of libraries when using libFLAC.
#   FLAC_FOUND        - true if libFLAC found.
#   FLAC_DEFINITIONS  - FLAC compile definitions.

if(FLAC_INCLUDE_DIR)
  # Already in cache, be silent
  set(FLAC_FIND_QUIETLY TRUE)
endif()

# libFLAC links against ogg, which is why OGG_FOUND is required below.
find_package(Ogg QUIET)

# pkg-config results are only used as search hints.
find_package(PkgConfig QUIET)
pkg_check_modules(PC_FLAC QUIET flac)

set(FLAC_VERSION ${PC_FLAC_VERSION})

find_path(FLAC_INCLUDE_DIR FLAC/stream_decoder.h
          HINTS ${PC_FLAC_INCLUDEDIR} ${PC_FLAC_INCLUDE_DIRS} ${FLAC_ROOT})

# MSVC built libraries can name them *_static, which is good as it distinguishes
# import libraries from static libraries with the same extension.
find_library(
  FLAC_LIBRARY
  NAMES FLAC libFLAC libFLAC_dynamic libFLAC_static
  HINTS ${PC_FLAC_LIBDIR} ${PC_FLAC_LIBRARY_DIRS} ${FLAC_ROOT})

# Handle the QUIETLY and REQUIRED arguments and set FLAC_FOUND to TRUE if all
# listed variables are TRUE.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
  FLAC
  REQUIRED_VARS FLAC_LIBRARY FLAC_INCLUDE_DIR OGG_FOUND
  VERSION_VAR FLAC_VERSION)

if(FLAC_FOUND)
  set(FLAC_INCLUDE_DIRS ${FLAC_INCLUDE_DIR})
  set(FLAC_LIBRARIES ${FLAC_LIBRARY} ${OGG_LIBRARIES})
  if(WIN32)
    set(FLAC_LIBRARIES ${FLAC_LIBRARIES} wsock32)
    get_filename_component(FLAC_LIBRARY_FILENAME ${FLAC_LIBRARY} NAME_WE)
    if(FLAC_LIBRARY_FILENAME MATCHES "libFLAC_static")
      # The static build must not see the dllimport decorations.
      set(FLAC_DEFINITIONS -DFLAC__NO_DLL)
    endif()
  endif()
  if(NOT TARGET FLAC::FLAC)
    add_library(FLAC::FLAC UNKNOWN IMPORTED)
    # set_target_properties() takes strict <name> <value> pairs, so every
    # multi-element or possibly empty value is passed as one quoted argument.
    set_target_properties(
      FLAC::FLAC
      PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${FLAC_INCLUDE_DIR}"
                 IMPORTED_LOCATION "${FLAC_LIBRARY}"
                 INTERFACE_LINK_LIBRARIES
                 "Ogg::ogg;$<$<BOOL:${WIN32}>:wsock32>"
                 INTERFACE_COMPILE_DEFINITIONS "${FLAC_DEFINITIONS}")
  endif()
endif()

mark_as_advanced(FLAC_INCLUDE_DIR FLAC_LIBRARY)
63 |
--------------------------------------------------------------------------------
/cmake/FindOgg.cmake:
--------------------------------------------------------------------------------
# FindOgg
# -------
#
# Find the native ogg includes and libraries.
#
# An installed Ogg CMake config package is tried first. The manual search below
# only runs when that did not define the Ogg::ogg target.
#
# Imported target defined when ogg is found: Ogg::ogg
#
# Variables set by the manual search:
#   OGG_INCLUDE_DIRS - where to find ogg.h, etc.
#   OGG_LIBRARIES    - list of libraries when using ogg.
#   OGG_FOUND        - true if ogg found.

find_package(Ogg CONFIG QUIET)

if(NOT TARGET Ogg::ogg)
  if(OGG_INCLUDE_DIR)
    # Already in cache, be silent
    set(OGG_FIND_QUIETLY TRUE)
  endif()

  # pkg-config results are only used as search hints.
  find_package(PkgConfig QUIET)
  pkg_check_modules(PC_OGG QUIET ogg)

  set(OGG_VERSION ${PC_OGG_VERSION})

  find_path(OGG_INCLUDE_DIR ogg/ogg.h HINTS ${PC_OGG_INCLUDEDIR}
                                            ${PC_OGG_INCLUDE_DIRS} ${OGG_ROOT})
  # MSVC built ogg may be named ogg_static. The provided project files name the
  # library with the lib prefix.
  find_library(
    OGG_LIBRARY
    NAMES ogg ogg_static libogg libogg_static
    HINTS ${PC_OGG_LIBDIR} ${PC_OGG_LIBRARY_DIRS} ${OGG_ROOT})
  # Handle the QUIETLY and REQUIRED arguments and set OGG_FOUND to TRUE if all
  # listed variables are TRUE.
  include(FindPackageHandleStandardArgs)
  find_package_handle_standard_args(
    Ogg
    REQUIRED_VARS OGG_LIBRARY OGG_INCLUDE_DIR
    VERSION_VAR OGG_VERSION)

  if(OGG_FOUND)
    set(OGG_LIBRARIES ${OGG_LIBRARY})
    set(OGG_INCLUDE_DIRS ${OGG_INCLUDE_DIR})

    # Expose the manually located library under the same target name that the
    # config package would have provided.
    if(NOT TARGET Ogg::ogg)
      add_library(Ogg::ogg UNKNOWN IMPORTED)
      set_target_properties(
        Ogg::ogg PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${OGG_INCLUDE_DIRS}"
                            IMPORTED_LOCATION "${OGG_LIBRARIES}")
    endif()
  endif()

  mark_as_advanced(OGG_INCLUDE_DIR OGG_LIBRARY)
endif()
49 |
--------------------------------------------------------------------------------
/cmake/FindSampleRate.cmake:
--------------------------------------------------------------------------------
# Try to find libsamplerate
#
# Provides the cmake config target - SampleRate::samplerate
#
# Inputs:
#   SampleRate_INC_DIR:  include directory for samplerate headers
#   SampleRate_LIB_DIR:  directory containing samplerate libraries
#   SampleRate_ROOT_DIR: directory containing samplerate installation
#
# Defines:
#   SampleRate_FOUND        - system has libsamplerate
#   SampleRate_INCLUDE_DIRS - the libsamplerate include directory
#   SampleRate_LIBRARIES    - link these to use libsamplerate
#

# Prefer an installed config package. QUIET because a missing config file is
# not an error here: the manual search below is the fallback.
find_package(SampleRate CONFIG QUIET)

if(NOT TARGET SampleRate::samplerate)
  find_path(
    SampleRate_INCLUDE_DIR samplerate.h
    PATHS ${SampleRate_INC_DIR} ${SampleRate_ROOT_DIR}/include
    PATH_SUFFIXES include)

  find_library(
    SampleRate_LIBRARY samplerate
    PATHS ${SampleRate_LIB_DIR} ${SampleRate_ROOT_DIR}
    PATH_SUFFIXES lib
    HINTS SAMPLERATE)

  set(SampleRate_INCLUDE_DIRS ${SampleRate_INCLUDE_DIR})
  set(SampleRate_LIBRARIES ${SampleRate_LIBRARY})

  # Only the cache entries created by find_path()/find_library() can be marked
  # as advanced; the *_DIRS/*_LIBRARIES copies above are normal variables.
  mark_as_advanced(SampleRate_INCLUDE_DIR SampleRate_LIBRARY)
  include(FindPackageHandleStandardArgs)
  find_package_handle_standard_args(
    SampleRate DEFAULT_MSG SampleRate_INCLUDE_DIRS SampleRate_LIBRARIES)

  if(SampleRate_FOUND)
    add_library(SampleRate::samplerate UNKNOWN IMPORTED)
    # NOTE(review): SAMPLERATE_DEP_LIBRARIES is not set anywhere in this
    # module; presumably the caller may provide it - confirm.
    set_target_properties(
      SampleRate::samplerate
      PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SampleRate_INCLUDE_DIRS}"
                 IMPORTED_LOCATION "${SampleRate_LIBRARIES}"
                 INTERFACE_LINK_LIBRARIES "${SAMPLERATE_DEP_LIBRARIES}")
    message(
      STATUS
        "Found libsamplerate: (lib: ${SampleRate_LIBRARIES} include: ${SampleRate_INCLUDE_DIRS})"
    )
  else()
    message(STATUS "libsamplerate not found.")
  endif()
endif() # NOT TARGET SampleRate::samplerate
51 |
--------------------------------------------------------------------------------
/cmake/pybind11-v2.11.1.patch:
--------------------------------------------------------------------------------
1 | diff --git a/CMakeLists.txt b/CMakeLists.txt
2 | index 87ec1034..eaef1a4c 100644
3 | --- a/CMakeLists.txt
4 | +++ b/CMakeLists.txt
5 | @@ -16,6 +16,10 @@ else()
6 | cmake_policy(VERSION 3.26)
7 | endif()
8 |
9 | +if(POLICY CMP0148)
10 | + cmake_policy(SET CMP0148 NEW)
11 | +endif()
12 | +
13 | # Avoid infinite recursion if tests include this as a subdirectory
14 | if(DEFINED PYBIND11_MASTER_PROJECT)
15 | return()
16 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | _autosummary
3 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = src
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Build the Docs
2 |
3 | ### Setup (do once)
4 |
5 | Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
6 | for example with `conda`:
7 |
8 | ```
9 | conda install sphinx
10 | pip install sphinx-book-theme
11 | ```
12 |
13 | ### Build
14 |
15 | Build the docs from `mlx/docs/`
16 |
17 | ```
18 | make html
19 | ```
20 |
21 | View the docs by running a server in `mlx/docs/build/html/`:
22 |
23 | ```
24 | python -m http.server
25 | ```
26 |
27 | and point your browser to `http://localhost:<port>` (`8000` by default).
28 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | breathe
3 | sphinx-book-theme
4 | mlx-data
5 |
--------------------------------------------------------------------------------
/docs/src/_static/mlx_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-explore/mlx-data/8c6d8b3efa6458812c8c3a70a87806b2bc1c056c/docs/src/_static/mlx_logo.png
--------------------------------------------------------------------------------
/docs/src/_static/mlx_logo_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-explore/mlx-data/8c6d8b3efa6458812c8c3a70a87806b2bc1c056c/docs/src/_static/mlx_logo_dark.png
--------------------------------------------------------------------------------
/docs/src/_templates/data_core_modules.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | {% block methods %}
8 |
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | {% for item in methods %}
15 | {%- if item not in inherited_members %}
16 | ~{{ name }}.{{ item }}
17 | {%- endif %}
18 | {%- endfor %}
19 | {% endif %}
20 | {% endblock %}
21 |
22 |
--------------------------------------------------------------------------------
/docs/src/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | # -*- coding: utf-8 -*-
4 |
5 | import os
6 | import subprocess
7 |
8 | # -- Project information -----------------------------------------------------
9 |
10 | project = "MLX Data"
11 | copyright = "2023, MLX Contributors"
12 | author = "MLX Contributors"
13 | version = "0.0.2"
14 | release = "0.0.2"
15 |
16 | # -- General configuration ---------------------------------------------------
17 |
18 | extensions = [
19 | "sphinx.ext.autodoc",
20 | "sphinx.ext.autosummary",
21 | "sphinx.ext.intersphinx",
22 | "sphinx.ext.napoleon",
23 | ]
24 |
25 | autosummary_generate = True
26 |
27 | intersphinx_mapping = {
28 | "python": ("https://docs.python.org/3", None),
29 | "numpy": ("https://numpy.org/doc/stable/", None),
30 | }
31 |
32 | templates_path = ["_templates"]
33 | html_static_path = ["_static"]
34 | source_suffix = ".rst"
35 | master_doc = "index"
36 | highlight_language = "python"
37 | pygments_style = "sphinx"
38 |
39 | # -- Options for HTML output -------------------------------------------------
40 |
41 | html_theme = "sphinx_book_theme"
42 |
43 | html_theme_options = {
44 | "show_toc_level": 2,
45 | "repository_url": "https://github.com/ml-explore/mlx-data",
46 | "use_repository_button": True,
47 | "navigation_with_keys": False,
48 | "logo": {
49 | "image_light": "_static/mlx_logo.png",
50 | "image_dark": "_static/mlx_logo_dark.png",
51 | },
52 | }
53 |
54 | # -- Options for HTMLHelp output ---------------------------------------------
55 |
56 | htmlhelp_basename = "mlx_data_doc"
57 |
--------------------------------------------------------------------------------
/docs/src/index.rst:
--------------------------------------------------------------------------------
1 | MLX Data
2 | ========
3 |
4 | MLX Data is a framework agnostic data loading library brought to you by Apple
5 | machine learning research.
6 |
7 | MLX Data can be used to load data for machine learning training or on its own
8 | for data pre-processing. You can use it with PyTorch, Jax or `MLX
9 | <https://github.com/ml-explore/mlx>`_.
10 |
11 | The goal of this library is to allow users to leverage multiple threads for
12 | data processing pipelines without the inflexibility of dealing with multiple
13 | processes or having to write in a symbolic language.
14 |
15 | .. note::
16 | In MLX Data pipelines you can use Python to process data, implement logic or cause side effects!
17 |
18 |
19 | .. toctree::
20 | :caption: Install
21 | :maxdepth: 1
22 |
23 | install
24 |
25 | .. toctree::
26 | :caption: Usage
27 | :maxdepth: 1
28 |
29 | quick_start
30 | buffers_streams_samples
31 | hf_datasets_streams
32 |
33 | .. toctree::
34 | :caption: Python API Reference
35 | :maxdepth: 1
36 |
37 | python/dataset
38 | python/buffer
39 | python/stream
40 | python/common_datasets
41 | python/tokenizing
42 | python/features
43 | python/miscellaneous
44 |
--------------------------------------------------------------------------------
/docs/src/python/buffer.rst:
--------------------------------------------------------------------------------
1 | .. _buffer:
2 |
3 | Buffer
4 | ======
5 |
6 | .. currentmodule:: mlx.data
7 |
8 | As also mentioned in :doc:`Buffers, Streams and Samples </buffers_streams_samples>`
9 | a :class:`Buffer` is an indexable container of
10 | samples. Using a buffer in python should feel very similar to accessing a list
11 | of samples.
12 |
13 | .. code-block:: python
14 |
15 | import mlx.data as dx
16 |
17 | numbers = dx.buffer_from_vector([{"x": i} for i in range(10)])
18 | evens = numbers.key_transform("x", lambda x: 2*x)
19 |
20 | print(evens)
21 | # prints Buffer(size=10, keys={'x'})
22 |
23 | print(evens[3])
24 | # prints {'x': array(6)}
25 |
26 | print(len(evens))
27 | # prints 10
28 |
29 | Factory methods
30 | ---------------
31 |
32 | We provide the following factory methods to create a buffer.
33 |
34 | .. autosummary::
35 | :toctree: _autosummary
36 |
37 | buffer_from_vector
38 | files_from_tar
39 |
40 | Buffer specific API
41 | -------------------
42 |
43 | The random access characteristics of a ``Buffer`` allow us to define some
44 | transformations that cannot be implemented or do not make sense for a
45 | :class:`Stream`.
46 |
47 | .. autosummary::
48 | :toctree: _autosummary
49 |
50 | Buffer.ordered_prefetch
51 | Buffer.partition
52 | Buffer.perm
53 | Buffer.shuffle
54 | Buffer.to_stream
55 |
--------------------------------------------------------------------------------
/docs/src/python/features.rst:
--------------------------------------------------------------------------------
1 | Feature extraction
2 | ==================
3 |
4 | This submodule provides some feature extraction utilities that can be used as
5 | ``key_transform`` functions in MLX data pipelines. Even though a C++
6 | implementation would allow for completely circumventing the GIL and better
7 | utilization of multiple threads, we find that an efficient numpy implementation
8 | can often be fast enough while providing significantly more flexibility.
9 |
10 | .. currentmodule:: mlx.data.features
11 |
12 | Audio Features
13 | --------------
14 |
15 | .. autosummary::
16 | :toctree: _autosummary
17 |
18 | WindowType
19 | FrequencyScale
20 | mfsc
21 |
--------------------------------------------------------------------------------
/docs/src/python/miscellaneous.rst:
--------------------------------------------------------------------------------
1 | Miscellaneous
2 | ==============
3 |
4 | .. currentmodule:: mlx.data
5 |
6 | FileFetcher
7 | -----------
8 |
9 | Several functions in MLX data can make use of a :class:`FileFetcher` object to
10 | fetch files from a remote location. See the :doc:`installation instructions </install>` to build MLX data with AWS support which adds the
11 | :class:`core.AWSFileFetcher` described below.
12 |
13 | .. autosummary::
14 | :toctree: _autosummary
15 | :recursive:
16 |
17 | core.AWSFileFetcher.__init__
18 | core.AWSFileFetcher.fetch
19 | core.AWSFileFetcher.prefetch
20 |
21 | A :class:`FileFetcher` can also be used standalone in your scripts to
22 | efficiently fetch remote content in background threads.
23 |
24 | .. code-block:: python
25 |
26 | from pathlib import Path
27 | from mlx.data.core import AWSFileFetcher
28 |
29 | LOCAL_CACHE = Path("/path/to/local/cache")
30 |
31 | ff = AWSFileFetcher(
32 | "my-cool-bucket",
33 | endpoint="https://my.endpoint.com/",
34 | local_prefix=LOCAL_CACHE,
35 | num_kept_files=100,
36 | )
37 |
38 | # When fetch returns my/remote/path/foo.npy will be in LOCAL_CACHE
39 | ff.fetch("my/remote/path/foo.npy")
40 | assert (LOCAL_CACHE / "my/remote/path/foo.npy").is_file()
41 |
42 | # We can prefetch in the background
43 | ff.prefetch(["foo_1.npy", "foo_2.npy"])
44 | ff.fetch("foo_1.npy")
45 | # process foo_1 while foo_2 downloads in the background
46 |
--------------------------------------------------------------------------------
/docs/src/python/stream.rst:
--------------------------------------------------------------------------------
1 | .. _stream:
2 |
3 | Stream
4 | ======
5 |
6 | .. currentmodule:: mlx.data
7 |
8 | Using a :class:`Stream` in python should feel like accessing an iterator of
9 | samples. Only the next sample can be fetched and the iteration may be restarted
10 | depending on the underlying source of the data (some truly online sources are
11 | not resettable).
12 |
13 | .. code-block:: python
14 |
15 | import mlx.data as dx
16 |
17 | # The samples are never all instantiated
18 | numbers = dx.stream_python_iterable(lambda: ({"x": i} for i in range(10**10)))
19 |
20 | # Filtering is done with transforms returning an empty sample
21 | evens = numbers.sample_transform(lambda s: s if s["x"] % 2 == 0 else dict())
22 |
23 | print(next(numbers))
24 | # prints {'x': array(0)}
25 | print(next(numbers))
26 | # prints {'x': array(1)}
27 |
28 | # Streams are pointers to the streams so evens is using numbers under the
29 | # hood. Since numbers was advanced now evens is advanced as well.
30 | print(next(evens))
31 | # prints {'x': array(2)}
32 | print(next(evens))
33 | # prints {'x': array(4)}
34 | print(next(numbers))
35 | # prints {'x': array(5)}
36 |
37 | # Streams can be reset.
38 | evens.reset()
39 | print(next(evens))
40 | print(next(evens))
41 | print(next(numbers))
42 | # prints {'x': array(0)}
43 | # {'x': array(2)}
44 | # {'x': array(3)}
45 |
46 |
47 | Factory methods
48 | ---------------
49 |
50 | We provide the following factory methods to create a stream. When used from
51 | python the most interesting one is probably :func:`stream_python_iterable`. Of
52 | course another good strategy is to start from a :class:`Buffer` that you then
53 | cast to a stream using :meth:`Buffer.to_stream`.
54 |
55 | .. autosummary::
56 | :toctree: _autosummary
57 |
58 | stream_csv_reader
59 | stream_csv_reader_from_string
60 | stream_line_reader
61 | stream_python_iterable
62 |
63 | Stream specific API
64 | -------------------
65 |
66 | :class:`Stream` has a more powerful API than :class:`Buffer`. It does not allow
67 | for random access, however, it allows for stream composing and prefetching.
68 | Stream composing is when a sample becomes the beginning of a new stream that
69 | can have arbitrary length.
70 |
71 | Streams also allow for filtering using the provided functions or a
72 | :meth:`Stream.sample_transform` that returns an empty dictionary (an empty
73 | Sample).
74 |
75 | .. autosummary::
76 | :toctree: _autosummary
77 |
78 | Stream.csv_reader_from_key
79 | Stream.line_reader_from_key
80 | Stream.dynamic_batch
81 | Stream.partition
82 | Stream.buffered
83 | Stream.repeat
84 | Stream.shuffle
85 | Stream.sliding_window
86 | Stream.prefetch
87 |
--------------------------------------------------------------------------------
/docs/src/python/tokenizing.rst:
--------------------------------------------------------------------------------
1 | Tokenizing with MLX data
2 | ========================
3 |
4 | .. currentmodule:: mlx.data
5 |
6 | MLX data allows sample transformations with the full flexibility of python which
7 | means that you could use any python tokenizer in a
8 | :meth:`Buffer.key_transform`. However, this is likely to be subject to the GIL
9 | which means that effectively only one sample can be tokenized at a time.
10 |
11 | A better choice is to use an :class:`mlx.data.core.CharTrie` to tokenize your
12 | data, taking full advantage of a multicore system. You can build the trie
13 | yourself or use one of the provided helpers to build a trie from a SentencePiece model
14 | or a plain text vocabulary file.
15 |
16 | .. code-block:: python
17 |
18 | from mlx.data.core import CharTrie, Tokenizer
19 |
20 | # We can build a trie ourselves
21 | trie = CharTrie()
22 | for t in b"a quick brown fox jumped over the lazy dog".split():
23 | trie.insert(t)
24 | trie.insert(b" ")
25 |
26 | tokenizer = Tokenizer(trie)
27 | print(tokenizer.tokenize_shortest(b"a quick brown fox jumped over the lazy dog"))
28 | # [0, 9, 1, 9, 2, 9, 3, 9, 4, 9, 5, 9, 6, 9, 7, 9, 8]
29 |
30 | # We can also add all the letters in the trie and then tokenize anything we want
31 | import string
32 | for l in string.ascii_letters:
33 | trie.insert(bytes(l, "utf-8"))
34 |
35 | print(tokenizer.tokenize_shortest(b"This is a quick example"))
36 | # [54, 16, 17, 27, 9, 17, 27, 9, 0, 9, 1, 9, 13, 32, 0, 21, 24, 20, 13]
37 |
38 | # The more useful option is to read the trie from a file, for instance an spm model
39 | from mlx.data.tokenizer_helpers import read_trie_from_spm
40 |
41 | trie, weights = read_trie_from_spm("path/to/spm/model")
42 | tokenizer = Tokenizer(trie, trie_key_scores=weights)
43 | tokenizer.tokenize_shortest(b"This is some more text to tokenize")
44 |
45 |
46 | .. autosummary::
47 | :toctree: _autosummary
48 | :template: data_core_modules.rst
49 | :recursive:
50 |
51 | core.Tokenizer
52 | core.CharTrie
53 | tokenizer_helpers.read_trie_from_vocab
54 | tokenizer_helpers.read_trie_from_spm
55 |
--------------------------------------------------------------------------------
/docs/src/quick_start.rst:
--------------------------------------------------------------------------------
1 | Quick Start Guide
2 | =================
3 |
4 | Load some data
5 | --------------
6 |
7 | In MLX data all samples are dictionaries of arrays. The library provides
8 | functions to download and iterate over some common datasets but the goal is to
9 | **provide functions that allow the user to load and process on the fly their
10 | own datasets**.
11 |
12 | Let's start with the simplest example on MNIST.
13 |
14 | .. code-block:: python
15 |
16 | # This is the standard way to import and access mlx.data
17 | import mlx.data as dx
18 |
19 | # Let's import MNIST loading
20 | from mlx.data.datasets import load_mnist
21 |
22 | # Loads a buffer with the MNIST images
23 | mnist_train = load_mnist(train=True)
24 |
25 | # Let's shuffle flatten and batch to prepare for MLP training
26 | mnist_mlp = (
27 | mnist_train
28 | .shuffle()
29 | .to_stream()
30 | .key_transform("image", lambda x: x.astype("float32").reshape(-1))
31 | .batch(32)
32 | .prefetch(4, 2)
33 | )
34 |
35 | # Now we can iterate over the batches in normal python
36 | for batch in mnist_mlp:
37 | x, y = batch["image"], batch["label"]
38 |
39 |
40 | MLX Data provides many :doc:`operations </python/dataset>` that transform
41 | samples so you can create arbitrarily complex pipelines.
42 |
43 |
44 | About the GIL
45 | -------------
46 |
47 | Python functions called by MLX data still run under the Global Interpreter
48 | Lock. To avoid serializing your data pipeline, either drop into Numpy or some
49 | other optimized library as quickly as possible or limit the processing time of
50 | the python part of the data pipeline.
51 |
52 | We would advise, however, to avoid premature optimization and only try to
53 | reduce GIL overhead if you are certain that it is limiting your data processing
54 | pipeline.
55 |
56 | The following are examples where we would use a python function for flexibility
57 | rather than have a specific C++ transformation.
58 |
59 | .. code-block:: python
60 |
61 | # Normalizing images in [0, 1]
62 | dset = dset.key_transform("image", lambda x: x.astype("float32") / 255)
63 |
64 | # Extracting mel spectrogram features
65 | # A big chunk of the time is spent computing the FFT which is done with the GIL off so...
66 | from mlx.data.features import mfsc
67 | dset = dset.key_transform("audio", mfsc(n_filterbank=80, sampling_freq=16000))
68 |
69 | # Filter stream samples based on values (empty dict means drop the sample)
70 | dset = dset.sample_transform(lambda s: s if s["length"] > 10 else dict())
71 |
--------------------------------------------------------------------------------
/mlx-data.pc.in:
--------------------------------------------------------------------------------
1 | # Find MLX Data
2 | #
3 | # Defines the following variables:
4 | #
5 | # MLX_DATA_FOUND : True if MLX Data is found
6 | # MLX_DATA_INCLUDE_DIRS : Include directory
7 | # MLX_DATA_LIBRARIES : Libraries to link against
8 | # MLX_DATA_CXX_FLAGS : Additional compiler flags
9 |
10 | @PACKAGE_INIT@
11 |
12 | include(@PACKAGE_MLX_DATA_CMAKE_INSTALL_MODULE_DIR@/MLXDataTargets.cmake)
13 |
14 | set_and_check(MLX_DATA_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@)
15 | set_and_check(MLX_DATA_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@)
16 | set(MLX_DATA_LIBRARIES mlxdata)
17 |
18 | find_library(MLX_DATA_LIBRARY mlxdata PATHS ${MLX_DATA_LIBRARY_DIRS})
19 |
20 | set_target_properties(mlx PROPERTIES
21 | CXX_STANDARD 17
22 | INTERFACE_COMPILE_OPTIONS "${MLX_DATA_CXX_FLAGS}"
23 | )
24 |
25 | include(FindPackageHandleStandardArgs)
26 | find_package_handle_standard_args(MLX_DATA DEFAULT_MSG MLX_DATA_LIBRARY MLX_DATA_INCLUDE_DIRS)
27 |
--------------------------------------------------------------------------------
/mlx/data/Buffer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Dataset.h"
6 | #include "mlx/data/buffer/Buffer.h"
7 |
8 | namespace mlx {
9 | namespace data {
10 |
11 | // Forward declaration of Stream so we can define toStream().
12 | class Stream;
13 |
14 | class Buffer : public Dataset {
15 | public:
16 | Buffer(const std::shared_ptr& self);
17 |
18 | Sample get(int64_t idx) const;
19 | int64_t size() const;
20 |
21 | Buffer batch(
22 | int64_t batch_size,
23 | const std::unordered_map& pad_values = {},
24 | const std::unordered_map& batch_dims = {}) const;
25 | Buffer batch(
26 | const std::vector& batch_sizes,
27 | const std::unordered_map& pad_values = {},
28 | const std::unordered_map& batch_dims = {}) const;
29 |
30 | Buffer dynamic_batch(
31 | const std::string& key,
32 | int64_t max_data_size = 0, // batch everything if <= 0
33 | const std::unordered_map& pad_values = {},
34 | const std::unordered_map& batch_dims = {}) const;
35 |
36 | Buffer dynamic_batch(
37 | const Buffer& size_buffer,
38 | const std::string& key,
39 | int64_t max_data_size,
40 | const std::unordered_map& pad_values = {},
41 | const std::unordered_map& batch_dims = {}) const;
42 |
43 | Stream ordered_prefetch(int prefetch_size, int num_thread) const;
44 |
45 | Buffer partition(int64_t num_partitions, int64_t partition) const;
46 | Buffer partition_if(bool cond, int64_t num_partitions, int64_t partition)
47 | const;
48 |
49 | Buffer perm(const std::vector& perm);
50 |
51 | Buffer shuffle();
52 | Buffer shuffle_if(bool cond);
53 |
54 | Stream to_stream();
55 |
56 | friend class Stream;
57 | };
58 |
59 | Buffer buffer_from_vector(const std::vector& data);
60 | Buffer buffer_from_vector(std::vector&& data);
61 | Buffer files_from_tar(
62 | const std::string& tarfile,
63 | bool nested = false,
64 | int num_threads = 1);
65 |
66 | } // namespace data
67 | } // namespace mlx
68 |
--------------------------------------------------------------------------------
/mlx/data/Sample.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 |
5 | #include "mlx/data/Sample.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace sample {
10 | std::vector keys(const Sample& dict) { // Returns the key of every entry of dict, in the map's iteration order.
11 | std::vector keys;
12 | for (auto& kv : dict) {
13 | keys.push_back(kv.first);
14 | }
15 | return keys;
16 | }
17 | // Looks up `key` in `input` and returns the stored array; throws std::runtime_error when the key is absent or the array type does not match.
18 | std::shared_ptr
19 | check_key(const Sample& input, const std::string& key, ArrayType type) {
20 | auto it = input.find(key);
21 | if (it == input.end()) {
22 | throw std::runtime_error("key <" + key + "> expected");
23 | }
24 | auto value = it->second;
25 | if (type != ArrayType::Any && value->type() != type) { // ArrayType::Any disables the type check
26 | throw std::runtime_error("invalid Array type");
27 | }
28 | return value;
29 | }
30 | } // namespace sample
31 | } // namespace data
32 | } // namespace mlx
33 |
--------------------------------------------------------------------------------
/mlx/data/Sample.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 |
8 | #include "mlx/data/Array.h"
9 |
10 | namespace mlx {
11 | namespace data {
12 | // A Sample is an unordered map (its key/value template arguments are not visible in this excerpt).
13 | typedef std::unordered_map> Sample;
14 |
15 | namespace sample {
16 | std::vector keys(const Sample& dict); // Lists the keys stored in dict.
17 | std::shared_ptr // Returns the entry stored under `key`, checking its presence and array type.
18 | check_key(const Sample& input, const std::string& key, ArrayType type);
19 |
20 | } // namespace sample
21 | } // namespace data
22 | } // namespace mlx
23 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Batch.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/Batch.h"
4 | #include "mlx/data/core/Utils.h"
5 |
6 | namespace mlx {
7 | namespace data {
8 | namespace buffer {
9 | Batch::Batch( // Fixed-size batching: each batch holds batch_size samples, except possibly the last one.
10 | const std::shared_ptr& op,
11 | int64_t batch_size,
12 | const std::unordered_map& pad_values,
13 | const std::unordered_map& batch_dims)
14 | : op_(op),
15 | batchSize_(batch_size),
16 | padValues_(pad_values),
17 | batchDims_(batch_dims) {
18 | if (batch_size <= 0) {
19 | throw std::runtime_error("Batch: batch size must be positive");
20 | }
21 | size_ = op->size() / batch_size; // number of full batches
22 | if (op->size() % batch_size) { // a remainder adds one final, smaller batch
23 | size_++;
24 | }
25 | }
26 | Batch::Batch( // Variable-size batching: batch i holds batch_sizes[i] consecutive samples.
27 | const std::shared_ptr& op,
28 | const std::vector& batch_sizes,
29 | const std::unordered_map& pad_values,
30 | const std::unordered_map& batch_dims)
31 | : op_(op),
32 | batchSize_(0), // 0 selects the variable-size path in get()
33 | batchOffsets_(batch_sizes.size()),
34 | batchSizes_(batch_sizes),
35 | padValues_(pad_values),
36 | batchDims_(batch_dims) {
37 | int64_t batch_sizes_sum = 0;
38 | for (int64_t i = 0; i < batch_sizes.size(); i++) {
39 | auto batch_size = batch_sizes[i];
40 | if (batch_size <= 0) {
41 | throw std::runtime_error("Batch: batch size must be positive");
42 | }
43 | batchOffsets_[i] = batch_sizes_sum; // index in op of the first sample of batch i
44 | batch_sizes_sum += batch_size;
45 | }
46 | if (batch_sizes_sum > op->size()) { // the sizes may cover fewer samples than op holds, but never more
47 | throw std::runtime_error("Batch: sum of batch sizes exceeds buffer size");
48 | }
49 | size_ = batch_sizes.size();
50 | }
51 | // Returns batch idx: fetches its samples from op_ and merges them with core::merge_batch; throws std::runtime_error on a bad index.
52 | Sample Batch::get(int64_t idx) const {
53 | if (idx < 0 || idx >= size_) {
54 | throw std::runtime_error("Batch: index out of range");
55 | }
56 | auto batch_size =
57 | (batchSize_ ? std::min(batchSize_, op_->size() - idx * batchSize_) // fixed-size mode: the last batch may be short
58 | : batchSizes_[idx]);
59 | auto batch_offset = (batchSize_ ? idx * batchSize_ : batchOffsets_[idx]);
60 | std::vector samples(batch_size);
61 | for (int64_t i = 0; i < batch_size; i++) {
62 | samples[i] = op_->get(batch_offset + i);
63 | }
64 | return core::merge_batch(samples, padValues_, batchDims_);
65 | }
66 | // Number of batches, as computed by the constructors.
67 | int64_t Batch::size() const {
68 | return size_;
69 | }
70 |
71 | } // namespace buffer
72 | } // namespace data
73 | } // namespace mlx
74 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Batch.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Buffer whose elements are batches built from consecutive samples of another buffer.
11 | class Batch : public Buffer {
12 | public:
13 | Batch( // fixed batch size
14 | const std::shared_ptr& op,
15 | int64_t batch_size,
16 | const std::unordered_map& pad_values = {},
17 | const std::unordered_map& batch_dims = {});
18 | Batch( // one explicit size per batch
19 | const std::shared_ptr& op,
20 | const std::vector& batch_sizes,
21 | const std::unordered_map& pad_values = {},
22 | const std::unordered_map& batch_dims = {});
23 |
24 | virtual Sample get(int64_t idx) const override;
25 | virtual int64_t size() const override;
26 |
27 | private:
28 | std::shared_ptr op_; // source buffer
29 | int64_t batchSize_; // fixed batch size, or 0 when batchSizes_ is used
30 | std::vector batchOffsets_; // first source index of each batch (variable-size mode)
31 | std::vector batchSizes_; // per-batch sample counts (variable-size mode)
32 | std::unordered_map padValues_; // forwarded to core::merge_batch
33 | std::unordered_map batchDims_; // forwarded to core::merge_batch
34 | int64_t size_; // number of batches
35 | };
36 |
37 | } // namespace buffer
38 | } // namespace data
39 | } // namespace mlx
40 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Buffer.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Base-class defaults: get() and size() always throw std::runtime_error ("NYI").
11 | Sample Buffer::get(const int64_t idx) const {
12 | throw std::runtime_error("Buffer::get() NYI");
13 | }
14 |
15 | int64_t Buffer::size() const {
16 | throw std::runtime_error("Buffer::size() NYI");
17 | }
18 | // Out-of-line definition of the virtual destructor.
19 | Buffer::~Buffer() {}
20 |
21 | } // namespace buffer
22 | } // namespace data
23 | } // namespace mlx
24 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Buffer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Sample.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Polymorphic base of all buffers: indexed access to samples plus a size.
11 | class Buffer {
12 | public:
13 | Buffer() {};
14 |
15 | // User-specific
16 | virtual Sample get(int64_t idx) const; // base implementation throws
17 | virtual int64_t size() const; // base implementation throws
18 |
19 | virtual ~Buffer();
20 | };
21 |
22 | } // namespace buffer
23 | } // namespace data
24 | } // namespace mlx
25 |
--------------------------------------------------------------------------------
/mlx/data/buffer/DynamicBatch.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include "mlx/data/buffer/Batch.h"
7 |
8 | namespace mlx {
9 | namespace data {
10 | namespace buffer {
11 | // Batch specialization; the implementation is not part of this excerpt.
12 | class DynamicBatch : public Batch {
13 | public:
14 | DynamicBatch(
15 | const std::shared_ptr& buffer,
16 | const std::string& key,
17 | int64_t max_data_size = 0, // batch everything if <= 0
18 | const std::unordered_map& pad_values = {},
19 | const std::unordered_map& batch_dims = {});
20 | // Overload that additionally takes a reference size buffer.
21 | DynamicBatch(
22 | const std::shared_ptr& buffer,
23 | const std::shared_ptr& ref_size_buffer,
24 | const std::string& key,
25 | int64_t max_data_size,
26 | const std::unordered_map& pad_values = {},
27 | const std::unordered_map& batch_dims = {});
28 |
29 | private:
30 | DynamicBatch( // takes a (buffer, sizes) pair of the shape returned by dynamic_batch_()
31 | std::pair, std::vector>
32 | buffer_with_sizes,
33 | const std::unordered_map& pad_values,
34 | const std::unordered_map& batch_dims);
35 |
36 | // returns sorted buffer with number of samples for each batch
37 | static std::pair, std::vector>
38 | dynamic_batch_(
39 | const std::shared_ptr& buffer,
40 | const std::shared_ptr& ref_sizebuffer,
41 | const std::string& key,
42 | int64_t max_data_size,
43 | const std::unordered_map& batch_dims);
44 | };
45 |
46 | } // namespace buffer
47 | } // namespace data
48 | } // namespace mlx
49 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FilesFromTAR.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/FilesFromTAR.h"
4 | #include "mlx/data/Array.h"
5 | #include "mlx/data/Sample.h"
6 | #include "mlx/data/core/TARReader.h"
7 |
8 | namespace mlx {
9 | namespace data {
10 | namespace buffer {
11 | // Reads the archive's file list once at construction; the TARReader is a local and is not kept.
12 | FilesFromTAR::FilesFromTAR(
13 | const std::string& tarfile,
14 | bool nested,
15 | int num_threads) {
16 | core::TARReader tarreader(tarfile, nested, num_threads);
17 | files_ = tarreader.get_file_list();
18 | }
19 | // Returns a one-entry sample whose "file" key is built from the idx-th entry of the file list.
20 | Sample FilesFromTAR::get(int64_t idx) const {
21 | if (idx < 0 || idx >= files_.size()) {
22 | throw std::runtime_error("FilesFromTAR: index out of range");
23 | }
24 | Sample res;
25 | res["file"] = std::make_shared(files_[idx]);
26 | return res;
27 | }
28 | // Number of entries in the file list.
29 | int64_t FilesFromTAR::size() const {
30 | return files_.size();
31 | }
32 |
33 | } // namespace buffer
34 | } // namespace data
35 | } // namespace mlx
36 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FilesFromTAR.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Buffer exposing the file list of a TAR archive, one sample per entry.
11 | class FilesFromTAR : public Buffer {
12 | public:
13 | FilesFromTAR(
14 | const std::string& tarfile,
15 | bool nested = false,
16 | int num_threads = 1);
17 |
18 | Sample get(int64_t idx) const override;
19 | virtual int64_t size() const override;
20 |
21 | private:
22 | std::vector files_; // file list obtained from core::TARReader at construction
23 | };
24 |
25 | } // namespace buffer
26 | } // namespace data
27 | } // namespace mlx
28 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FromStream.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/FromStream.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace buffer {
8 | // Drains the stream into a vector first, then hands that vector to the FromVector base.
9 | FromStream::FromStream(
10 | const std::shared_ptr& stream,
11 | int64_t size)
12 | : FromVector(bufferize_(stream, size)) {}
13 | // Reads up to `size` samples from the stream; a negative size reads until the stream yields an empty sample.
14 | std::vector FromStream::bufferize_(
15 | std::shared_ptr stream,
16 | int64_t size) {
17 | std::vector buffer;
18 | if (size > 0) {
19 | buffer.reserve(size);
20 | }
21 | for (int64_t i = 0; (size < 0) || (i < size); i++) { // size == 0 reads nothing
22 | auto sample = stream->next();
23 | if (sample.empty()) { // an empty sample ends the read early
24 | break;
25 | }
26 | buffer.push_back(sample);
27 | }
28 | return buffer;
29 | }
30 |
31 | } // namespace buffer
32 | } // namespace data
33 | } // namespace mlx
34 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FromStream.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/FromVector.h"
6 | #include "mlx/data/stream/Stream.h"
7 |
8 | namespace mlx {
9 | namespace data {
10 | namespace buffer {
11 | // Buffer filled by reading samples out of a stream at construction time.
12 | class FromStream : public FromVector {
13 | public:
14 | FromStream(const std::shared_ptr& stream, int64_t size = -1); // default -1: read until the stream returns an empty sample
15 |
16 | private:
17 | static std::vector bufferize_( // static: it runs before the FromVector base is constructed
18 | std::shared_ptr stream,
19 | int64_t size);
20 | };
21 |
22 | } // namespace buffer
23 | } // namespace data
24 | } // namespace mlx
25 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FromVector.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/FromVector.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace buffer {
8 | // Both constructors store the samples and then reject any empty sample.
9 | FromVector::FromVector(const std::vector& data) : buffer_(data) {
10 | check_samples_();
11 | }
12 |
13 | FromVector::FromVector(std::vector&& data) : buffer_(std::move(data)) {
14 | check_samples_();
15 | }
16 | // Returns a copy of the idx-th stored sample; throws std::out_of_range on a bad index.
17 | Sample FromVector::get(int64_t idx) const {
18 | if (idx < 0 || idx >= buffer_.size()) {
19 | throw std::out_of_range("FromVector: index out of range");
20 | }
21 | return buffer_[idx];
22 | }
23 | // Number of stored samples.
24 | int64_t FromVector::size() const {
25 | return buffer_.size();
26 | }
27 | // Throws std::runtime_error if any stored sample is empty.
28 | void FromVector::check_samples_() const {
29 | for (auto& sample : buffer_) {
30 | if (sample.empty()) {
31 | throw std::runtime_error("FromVector: unexpected empty sample");
32 | }
33 | }
34 | }
35 |
36 | } // namespace buffer
37 | } // namespace data
38 | } // namespace mlx
39 |
--------------------------------------------------------------------------------
/mlx/data/buffer/FromVector.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Buffer backed by an in-memory vector of samples.
11 | class FromVector : public Buffer {
12 | public:
13 | FromVector(const std::vector& data); // copies the vector
14 | FromVector(std::vector&& data); // moves the vector in
15 |
16 | Sample get(int64_t idx) const override;
17 | virtual int64_t size() const override;
18 |
19 | private:
20 | void check_samples_() const; // throws if a stored sample is empty
21 | std::vector buffer_;
22 | };
23 |
24 | } // namespace buffer
25 | } // namespace data
26 | } // namespace mlx
27 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Partition.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/Partition.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace buffer {
8 |
9 | Partition::Partition( // Exposes every num_partitions-th sample of buffer, starting at offset `partition`.
10 | std::shared_ptr buffer,
11 | int64_t num_partitions,
12 | int64_t partition)
13 | : buffer_(buffer), numPartitions_(num_partitions), partition_(partition) {
14 | if (num_partitions <= 0) { // zero is rejected here as well, so the message is accurate and the divisions below are safe
15 | throw std::runtime_error(
16 | "Partition: number of partitions must be positive");
17 | }
18 | if (partition < 0 || partition >= num_partitions) {
19 | throw std::runtime_error("Partition: selected partition is out of range");
20 | }
21 | size_ = buffer->size() / num_partitions; // every partition gets at least this many samples
22 | if (partition_ < (buffer->size() % num_partitions)) { // the first (size % num_partitions) partitions get one extra sample
23 | size_++;
24 | }
25 | }
26 |
27 | Sample Partition::get(int64_t idx) const { // Returns the idx-th sample of this partition; throws std::runtime_error on a bad index.
28 | if (idx < 0 || idx >= size_) { // valid indices are 0 .. size_ - 1
29 | throw std::runtime_error("Partition: index out of range");
30 | }
31 | return buffer_->get(idx * numPartitions_ + partition_); // samples are interleaved across partitions
32 | }
33 |
34 | int64_t Partition::size() const { // Number of samples in this partition, computed by the constructor.
35 | return size_;
36 | }
37 |
38 | } // namespace buffer
39 | } // namespace data
40 | } // namespace mlx
41 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Partition.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // View over one of num_partitions interleaved slices of another buffer.
11 | class Partition : public Buffer {
12 | public:
13 | Partition(
14 | std::shared_ptr buffer,
15 | int64_t num_partitions,
16 | int64_t partition);
17 |
18 | virtual Sample get(int64_t idx) const override;
19 | virtual int64_t size() const override;
20 |
21 | private:
22 | std::shared_ptr buffer_; // source buffer
23 | int64_t numPartitions_; // stride between consecutive samples of this partition
24 | int64_t partition_; // offset of this partition's first sample
25 | int64_t size_; // number of samples in this partition
26 | };
27 |
28 | } // namespace buffer
29 | } // namespace data
30 | } // namespace mlx
31 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Perm.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/Perm.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace buffer {
8 | // Wraps op so that element idx maps to op element perm[idx]; the permutation is validated on construction.
9 | Perm::Perm(const std::shared_ptr& op, const std::vector& perm)
10 | : op_(op) {
11 | set_perm_(perm);
12 | }
13 | // Returns op_'s sample at position perm_[idx]; throws std::runtime_error on a bad index.
14 | Sample Perm::get(int64_t idx) const {
15 | if (idx < 0 || idx >= perm_.size()) {
16 | throw std::runtime_error("Perm: index out of range");
17 | }
18 | return op_->get(perm_[idx]);
19 | }
20 | // The size is the length of the permutation, not the size of op_.
21 | int64_t Perm::size() const {
22 | return perm_.size();
23 | }
24 | // Stores perm after checking that every entry is a valid index into op_; duplicate entries are not rejected.
25 | void Perm::set_perm_(const std::vector& perm) {
26 | for (auto idx : perm) {
27 | if (idx < 0 || idx >= op_->size()) {
28 | throw std::runtime_error("Perm: permutation index out of range");
29 | }
30 | }
31 | perm_ = perm;
32 | }
33 | // Read-only access to the stored permutation.
34 | const std::vector& Perm::get_perm() {
35 | return perm_;
36 | }
37 |
38 | } // namespace buffer
39 | } // namespace data
40 | } // namespace mlx
41 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Perm.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Buffer.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Buffer that re-indexes another buffer through a list of indices.
11 | class Perm : public Buffer {
12 | public:
13 | Perm(const std::shared_ptr& op, const std::vector& perm);
14 |
15 | Sample get(int64_t idx) const override;
16 | virtual int64_t size() const override;
17 |
18 | const std::vector& get_perm(); // NOTE(review): not const-qualified although it only returns a const reference
19 |
20 | private:
21 | std::shared_ptr op_; // source buffer
22 | std::vector perm_; // validated indices into op_
23 |
24 | // unsafe if it was public
25 | void set_perm_(const std::vector& perm);
26 | };
27 |
28 | } // namespace buffer
29 | } // namespace data
30 | } // namespace mlx
31 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Shuffle.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/buffer/Shuffle.h"
4 | #include "mlx/data/core/State.h"
5 |
6 | #include
7 | #include // iota
8 |
9 | namespace mlx {
10 | namespace data {
11 | namespace buffer {
12 | // A Perm whose permutation is drawn at random when the object is constructed.
13 | Shuffle::Shuffle(const std::shared_ptr& buffer)
14 | : Perm(buffer, rand_perm_(buffer->size())) {}
15 | // Returns the indices 0..size-1 in random order, shuffled with the generator obtained from core::get_state().
16 | std::vector Shuffle::rand_perm_(int64_t size) {
17 | auto state = core::get_state();
18 | std::vector perm(size);
19 | std::iota(perm.begin(), perm.end(), 0);
20 | std::shuffle(perm.begin(), perm.end(), state->randomGenerator);
21 | return perm;
22 | }
23 |
24 | } // namespace buffer
25 | } // namespace data
26 | } // namespace mlx
27 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Shuffle.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/buffer/Perm.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 |
11 | class Shuffle : public Perm { // Perm with a randomly drawn permutation over the whole source buffer.
12 | public:
13 | Shuffle(const std::shared_ptr& buffer);
14 |
15 | private:
16 | static std::vector rand_perm_(int64_t size); // static: the constructor calls it while initializing the Perm base, before any member function may legally run on the object
17 | };
18 |
19 | } // namespace buffer
20 | } // namespace data
21 | } // namespace mlx
22 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Transform.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 |
5 | #include "mlx/data/buffer/Transform.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace buffer {
10 | // Single-op convenience constructor; the op is stored as a one-element list.
11 | Transform::Transform(
12 | const std::shared_ptr& od,
13 | const std::shared_ptr& op)
14 | : od_(od), ops_({op}) {};
15 | // Constructor taking the ops to apply, in order.
16 | Transform::Transform(
17 | const std::shared_ptr& od,
18 | const std::vector>& ops)
19 | : od_(od), ops_(ops) {};
20 | // Fetches sample idx from od_ and applies each op in turn; throws std::runtime_error if the sample is empty at any stage.
21 | Sample Transform::get(const int64_t idx) const {
22 | auto t_sample = od_->get(idx);
23 | if (t_sample.empty()) {
24 | throw std::runtime_error("Transform: cannot return empty sample");
25 | }
26 | for (auto& op : ops_) {
27 | t_sample = op->apply(t_sample);
28 | if (t_sample.empty()) {
29 | throw std::runtime_error("Transform: cannot return empty sample");
30 | }
31 | }
32 | return t_sample;
33 | }
34 | // Same size as the wrapped buffer.
35 | int64_t Transform::size() const {
36 | return od_->size();
37 | }
38 |
39 | } // namespace buffer
40 | } // namespace data
41 | } // namespace mlx
42 |
--------------------------------------------------------------------------------
/mlx/data/buffer/Transform.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | #include "mlx/data/buffer/Buffer.h"
10 | #include "mlx/data/op/Op.h"
11 |
12 | namespace mlx {
13 | namespace data {
14 | namespace buffer {
15 | // Buffer that applies one or more ops to every sample of another buffer on access.
16 | class Transform : public Buffer {
17 | public:
18 | Transform(
19 | const std::shared_ptr& od,
20 | const std::shared_ptr& op);
21 | Transform(
22 | const std::shared_ptr& od,
23 | const std::vector>& ops);
24 |
25 | virtual Sample get(int64_t idx) const override;
26 |
27 | virtual int64_t size() const override;
28 |
29 | protected:
30 | std::shared_ptr od_; // source buffer
31 | std::vector> ops_; // ops applied in order by get()
32 | };
33 |
34 | } // namespace buffer
35 | } // namespace data
36 | } // namespace mlx
37 |
--------------------------------------------------------------------------------
/mlx/data/core/BPETokenizer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | #include "mlx/data/core/Trie.h"
10 |
11 | namespace mlx {
12 | namespace data {
13 | namespace core {
14 | // Table of (left, right) -> token merges; the implementation is not part of this excerpt.
15 | class BPEMerges {
16 | public:
17 | void add(const std::string& left, const std::string& right, int64_t token);
18 | std::pair can_merge(
19 | std::string_view left,
20 | std::string_view right) const;
21 | // Iterator overload: views [left, middle) and [middle, end) as string_views and forwards to the overload above.
22 | template
23 | std::pair
24 | can_merge(iterator_type left, iterator_type middle, iterator_type end) const {
25 | // switch to std::string_view(left, middle) when in C++20
26 | return can_merge(
27 | std::string_view(&(*left), std::distance(left, middle)), // NOTE(review): dereferences left and middle — presumably callers never pass an empty or end position; confirm
28 | std::string_view(&(*middle), std::distance(middle, end)));
29 | }
30 |
31 | private:
32 | std::unordered_set strings_; // presumably owns the characters the string_view keys below refer to — TODO confirm
33 | std::unordered_map<
34 | std::string_view,
35 | std::unordered_map>
36 | merges_;
37 | };
38 | // Tokenizer built from a symbol trie and a merge table.
39 | class BPETokenizer {
40 | public:
41 | BPETokenizer(
42 | std::shared_ptr> symbols,
43 | std::shared_ptr merges);
44 |
45 | std::vector tokenize(std::string_view input) const;
46 |
47 | private:
48 | std::shared_ptr> symbols_;
49 | std::shared_ptr merges_;
50 | };
51 |
52 | } // namespace core
53 | } // namespace data
54 | } // namespace mlx
55 |
--------------------------------------------------------------------------------
/mlx/data/core/BatchShape.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 |
5 | #include "mlx/data/core/BatchShape.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace core {
10 | // Default: batch by prefixing a new leading dimension. With a dim: batch by concatenating along that dimension.
11 | BatchShape::BatchShape() : nodim_(true), num_sample_(0) {};
12 | BatchShape::BatchShape(int dim) : dim_(dim), nodim_(false), num_sample_(0) {};
13 | // Number of elements in the batch shape (product of its dimensions; 1 when the shape is empty).
14 | int64_t BatchShape::size() const {
15 | int64_t size = 1;
16 | for (auto dim : shape_) {
17 | size *= dim;
18 | }
19 | return size;
20 | }
21 | // Shape accumulated so far.
22 | const std::vector& BatchShape::shape() const {
23 | return shape_;
24 | }
25 | // Folds one more sample shape into the batch shape; throws std::runtime_error on a rank mismatch or an out-of-range batch dimension.
26 | void BatchShape::add(const std::vector& shape) {
27 | if (nodim_) {
28 | if (num_sample_ == 0) {
29 | shape_.resize(shape.size() + 1, 0); // slot 0 counts the samples, the rest mirrors the sample shape
30 | std::copy(shape.begin(), shape.end(), shape_.begin() + 1);
31 | } else {
32 | if ((shape.size() + 1) != shape_.size()) {
33 | throw std::runtime_error(
34 | "BatchShape: batched arrays expected to have consistent shapes");
35 | }
36 | for (int d = 0; d < shape.size(); d++) {
37 | shape_[d + 1] = std::max(shape_[d + 1], shape[d]); // every non-batch dimension grows to the largest sample
38 | }
39 | }
40 | shape_[0] += 1;
41 | } else {
42 | int64_t dim = dim_;
43 | if (dim < 0) { // negative dims count from the end
44 | dim = shape.size() + dim;
45 | }
46 | if (dim < 0 || dim >= shape.size()) { // explicit sign test: do not rely on signed-to-unsigned conversion to reject a still-negative dim
47 | throw std::runtime_error("BatchShape: dimension out of bound");
48 | }
49 | if (num_sample_ == 0) {
50 | shape_ = shape;
51 | } else {
52 | if (shape.size() != shape_.size()) {
53 | throw std::runtime_error(
54 | "BatchShape: batched arrays expected to have consistent shapes");
55 | }
56 | for (int d = 0; d < shape.size(); d++) {
57 | if (d == dim) {
58 | shape_[d] += shape[d]; // sizes add up along the batch dimension
59 | } else {
60 | shape_[d] = std::max(shape_[d], shape[d]);
61 | }
62 | }
63 | }
64 | }
65 | num_sample_++;
66 | }
67 | // Number of shapes added since construction or the last clear().
68 | int64_t BatchShape::num_sample() const {
69 | return num_sample_;
70 | }
71 | // Resets the accumulated shape and sample count; the batching mode is kept.
72 | void BatchShape::clear() {
73 | shape_.clear();
74 | num_sample_ = 0;
75 | }
76 | // Size of one dimension of the batch shape; negative dims count from the end. Throws std::runtime_error when out of range.
77 | int64_t BatchShape::operator[](int dim) const {
78 | if (dim < 0) {
79 | dim = shape_.size() + dim;
80 | }
81 | if (dim < 0 || dim >= shape_.size()) { // explicit sign test, same as in add()
82 | throw std::runtime_error("BatchShape: dimension out of bound");
83 | }
84 | return shape_[dim];
85 | }
86 |
87 | } // namespace core
88 | } // namespace data
89 | } // namespace mlx
90 |
--------------------------------------------------------------------------------
/mlx/data/core/BatchShape.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 | #include
5 |
6 | namespace mlx {
7 | namespace data {
8 | namespace core {
9 | // NOTE(review): no `#pragma once` or include guard is visible above this namespace in this header — confirm.
10 | // Computes the shape of a batch
11 | class BatchShape {
12 | public:
13 | // batch by prefixing an extra dim
14 | BatchShape();
15 |
16 | // batch along a specified dim
17 | BatchShape(int dim);
18 |
19 | // add a shape to the batch
20 | void add(const std::vector& shape);
21 | // forget all shapes added so far
22 | void clear();
23 |
24 | int64_t size() const; // product of the batch dimensions
25 | const std::vector& shape() const;
26 | int64_t num_sample() const; // number of shapes added
27 | int64_t operator[](int dim) const; // one dimension of the batch shape; negative dims count from the end
28 |
29 | private:
30 | std::vector shape_;
31 | int dim_; // batch dimension; left unset by the default constructor, which sets nodim_ instead
32 | bool nodim_; // true: prefix a new batch dimension instead of using dim_
33 | int64_t num_sample_;
34 | };
35 | } // namespace core
36 | } // namespace data
37 | } // namespace mlx
38 |
--------------------------------------------------------------------------------
/mlx/data/core/CSVReader.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #include "bxzstr/bxzstr.hpp"
11 |
12 | namespace mlx {
13 | namespace data {
14 | namespace core {
15 | // CSV reader with configurable separator and quote characters; the implementation is not part of this excerpt.
16 | class CSVReader {
17 | public:
18 | CSVReader( // reads from a file path
19 | const std::string& file,
20 | const char sep = ',',
21 | const char quote = '"');
22 | CSVReader( // reads from an already-open shared stream object
23 | const std::shared_ptr& uf,
24 | const char sep = ',',
25 | const char quote = '"');
26 | std::vector next(); // presumably returns the fields of the next record — TODO confirm against the implementation
27 | void reset();
28 |
29 | private:
30 | void parse_line_(
31 | const std::string& line,
32 | std::vector& fields,
33 | int& current_state,
34 | std::string& current_field) const;
35 |
36 | std::string filename_;
37 | int numFields_ = -1;
38 | int numLine_ = 0;
39 | char sep_ = ',';
40 | char quote_ = '"';
41 | const char lf_ = '\n'; // const members: the class is not copy-assignable
42 | const char cr_ = '\r';
43 | std::shared_ptr uf_;
44 | std::shared_ptr f_;
45 | };
46 |
47 | } // namespace core
48 | } // namespace data
49 | } // namespace mlx
50 |
--------------------------------------------------------------------------------
/mlx/data/core/FileFetcher.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Array.h"
6 | #include "mlx/data/core/ThreadPool.h"
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | namespace mlx {
17 | namespace data {
18 | namespace core {
19 | // Handle wrapping a rank value; FileFetcher is a friend and can read it.
20 | class FileFetcherHandle {
21 | public:
22 | FileFetcherHandle(int64_t rank) : rank_(rank) {};
23 |
24 | private:
25 | int64_t rank_;
26 | friend class FileFetcher;
27 | };
28 |
29 | // Note that FileFetcher holds a weak reference on a given filename
30 | // If it is not valid anymore, it will be fetched again
31 | class FileFetcher {
32 | public:
33 | // In a multi-threaded environment, use numKeptFiles with caution: Make
34 | // sure it is large enough, to ensure threads won't compete on the files
35 | // to keep locally.
36 | FileFetcher(
37 | int num_prefetch_max = 1,
38 | int num_prefetch_threads = 1,
39 | int num_kept_files = 0,
40 | bool verbose = false);
41 |
42 | void prefetch(const std::vector& filenames);
43 |
44 | // Must be called in the destructor of any subclass
45 | // because prefetch calls the virtual backendFetch()
46 | // which would then be destroyed before ~FileFetcher()
47 | void cancel_prefetch();
48 | // Returns a handle for the given file.
49 | std::shared_ptr fetch(const std::string& filename) const;
50 |
51 | // Erase a file from cache, and call backend erase
52 | void erase(const std::string& filename) const;
53 | // Backend hooks, virtual so that subclasses can override them.
54 | virtual void backend_fetch(const std::string& filename) const;
55 |
56 | virtual void backend_erase(const std::string& filename) const;
57 |
58 | virtual ~FileFetcher();
59 |
60 | protected:
61 | void fill_queue_() const;
62 | std::unique_ptr threadPool_;
63 | mutable std::deque prefetchFilenames_;
64 | mutable std::shared_mutex mutex_; // presumably guards the mutable containers below — TODO confirm against the implementation
65 | int numPrefetchMax_;
66 | int numKeptFiles_;
67 | mutable int64_t fileRank_;
68 | bool verbose_;
69 |
70 | mutable std::unordered_map> queuedFiles_;
71 |
72 | mutable std::unordered_map>
73 | cachedFiles_;
74 | };
75 |
76 | } // namespace core
77 | } // namespace data
78 | } // namespace mlx
79 |
--------------------------------------------------------------------------------
/mlx/data/core/Levenshtein.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/Array.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace core {
8 | // Declares levenshtein(); judging by the name it computes an edit distance between arr1 and arr2 — the role of len1/len2 is not visible here (TODO confirm).
9 | std::shared_ptr levenshtein(
10 | const std::shared_ptr arr1,
11 | const std::shared_ptr len1,
12 | const std::shared_ptr arr2,
13 | const std::shared_ptr len2);
14 |
15 | } // namespace core
16 | } // namespace data
17 | } // namespace mlx
18 |
--------------------------------------------------------------------------------
/mlx/data/core/Numpy.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Array.h"
6 |
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | namespace mlx {
14 | namespace data {
15 | namespace core {
16 | // Loads a numpy (npy) array from the file named `filename`.
17 | std::shared_ptr load_numpy(const std::string& filename);
18 |
19 | /// @brief Read a numpy (npy) array from an input stream.
20 | ///
21 | /// See also:
22 | /// https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html
23 | ///
24 | /// @param stream stream to read from; @param filename name associated with the stream (optional, defaults to empty; its exact use is not visible here)
25 | /// @return Array with the contents of the file
26 | std::shared_ptr load_numpy(
27 | std::istream& stream,
28 | const std::string& filename = ""); // empty default: a std::string must never be constructed from a null pointer
29 |
30 | } // namespace core
31 | } // namespace data
32 | } // namespace mlx
33 |
--------------------------------------------------------------------------------
/mlx/data/core/State.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/core/State.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace core {
8 | // Process-wide reference state; get_state() hands out per-thread copies of it.
9 | static State global_state;
10 | // Reseeds the global generator and bumps the version so per-thread copies are refreshed on their next get_state().
11 | void set_state(int64_t seed) {
12 | global_state.randomGenerator = std::mt19937(seed);
13 | global_state.version++;
14 | }
15 | // Returns this thread's State, re-copying global_state when the thread has none yet or its version is stale.
16 | std::shared_ptr get_state() {
17 | static thread_local std::shared_ptr state;
18 | if (!state || (state->version != global_state.version)) { // NOTE(review): global_state is read without synchronization here — relies on set_state() not running concurrently; confirm
19 | state = std::make_shared(global_state);
20 | }
21 | return state;
22 | };
23 |
24 | } // namespace core
25 | } // namespace data
26 | } // namespace mlx
27 |
--------------------------------------------------------------------------------
/mlx/data/core/State.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 |
8 | #include "mlx/data/core/ThreadPool.h"
9 |
10 | namespace mlx {
11 | namespace data {
12 | namespace core {
13 | // Random generator plus a version counter that is compared to detect reseeding.
14 | struct State {
15 | std::mt19937 randomGenerator;
16 | int64_t version = 0; // initialized so a default-constructed State never carries an indeterminate version
17 | };
18 |
19 | // thread-local state
20 | std::shared_ptr get_state();
21 |
22 | // should be called in main thread only
23 | void set_state(int64_t seed);
24 |
25 | } // namespace core
26 | } // namespace data
27 | } // namespace mlx
28 |
--------------------------------------------------------------------------------
/mlx/data/core/TARReader.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Array.h"
6 |
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | namespace mlx {
14 | namespace data {
15 | namespace core {
16 |
17 | typedef std::unordered_map> // NOTE(review): presumably maps an archive member name to where it lives in the TAR — confirm in TARReader.cpp
18 | TARFileIndex;
19 | // Looks up files by name through a TARFileIndex kept for the archive `filename`.
20 | class TARReader {
21 | public:
22 | TARReader(
23 | const std::string& filename,
24 | bool nested = false, // NOTE(review): presumably also indexes TAR files stored inside the archive — confirm
25 | int num_threads = 1); // NOTE(review): presumably the parallelism used while building the index — confirm
26 |
27 | bool contains(const std::string& filename); // NOTE(review): presumably an index_ membership test — confirm
28 | std::shared_ptr get(const std::string& filename); // NOTE(review): presumably the member's contents; behaviour for unknown names is not visible here
29 | std::vector get_file_list(); // NOTE(review): presumably the member names held in index_ — confirm
30 |
31 | private:
32 | std::string filename_; // NOTE(review): presumably the archive path given to the constructor
33 | TARFileIndex index_;
34 | };
35 |
36 | } // namespace core
37 | } // namespace data
38 | } // namespace mlx
39 |
--------------------------------------------------------------------------------
/mlx/data/core/ThreadController.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 | #include
5 |
6 | namespace mlx {
7 | namespace data {
8 | namespace core {
9 |
10 | typedef std::vector ThreadControllerState; // value produced by ThreadController::limit() and handed back to restore()
11 |
12 | struct ThreadControllerSym; // only declared here; NOTE(review): presumably the element type behind ThreadController::symbols_ — confirm
13 | // limit()/restore() pair: ThreadPool workers call limit() right before running a task and pass its result to restore() right after.
14 | class ThreadController {
15 | public:
16 | ThreadController();
17 |
18 | ThreadControllerState limit(); // NOTE(review): presumably lowers the thread counts of the tracked libraries and returns the previous values — confirm
19 | void restore(const ThreadControllerState& state); // takes the value returned by an earlier limit()
20 |
21 | private:
22 | std::vector> symbols_;
23 | };
25 | } // namespace core
26 | } // namespace data
27 | } // namespace mlx
28 |
--------------------------------------------------------------------------------
/mlx/data/core/ThreadPool.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/core/ThreadPool.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace core {
8 |
9 | std::shared_ptr ThreadPool::thread_controller = nullptr; // static member shared by all pools; created by the first ThreadPool constructor that finds it empty
10 | // Starts `thread_count` worker threads; also creates the shared thread_controller if no pool has done so yet.
11 | ThreadPool::ThreadPool(size_t thread_count) {
12 | if (!thread_controller) { // NOTE(review): unsynchronized check-then-create — confirm pools are never constructed concurrently
13 | thread_controller = std::make_shared();
14 | }
15 | for (size_t i = 0; i < thread_count; ++i) {
16 | // start waiting threads. Workers listen for changes through
17 | // the ThreadPool member condition_variable
18 | threads_.emplace_back(std::thread([&]() { // captures this pool by reference; the destructor joins the workers before returning
19 | std::unique_lock queue_lock(task_mutex_, std::defer_lock);
20 |
21 | while (true) {
22 | queue_lock.lock();
23 | task_cv_.wait(queue_lock, [&]() -> bool {
24 | return !tasks_.empty() || stop_threads_;
25 | });
26 |
27 | // used by dtor to stop all threads without having to
28 | // unceremoniously stop tasks. The tasks must all be
29 | // finished, lest we break a promise and risk a `future`
30 | // object throwing an exception.
31 | if (stop_threads_ && tasks_.empty())
32 | return;
33 |
34 | // to initialize temp_task, we must move the unique_ptr
35 | // from the queue to the local stack. Since a unique_ptr
36 | // cannot be copied (obviously), it must be explicitly
37 | // moved. This transfers ownership of the pointed-to
38 | // object to *this, as specified in 20.11.1.2.1
39 | // [unique.ptr.single.ctor].
40 | auto temp_task = std::move(tasks_.front());
41 |
42 | tasks_.pop();
43 | queue_lock.unlock(); // the task runs without the queue lock held
44 |
45 | auto thread_state = thread_controller->limit(); // limit()/restore() bracket every task
46 | (*temp_task)();
47 | thread_controller->restore(thread_state); // NOTE(review): skipped if the task call throws — confirm tasks cannot throw here
48 | }
49 | }));
50 | }
51 | }
52 | // Stops the pool: workers finish the tasks still queued, then exit and are joined.
53 | ThreadPool::~ThreadPool() {
54 | { // set the flag under task_mutex_ so a worker between its predicate check and blocking cannot miss the wake-up
55 | std::lock_guard<decltype(task_mutex_)> lock(task_mutex_);
56 | stop_threads_ = true;
57 | }
58 | task_cv_.notify_all(); // sent after unlocking so woken workers can take the mutex straight away
59 | for (std::thread& thread : threads_) { thread.join(); }
60 | }
61 | } // namespace core
62 | } // namespace data
63 | } // namespace mlx
64 |
--------------------------------------------------------------------------------
/mlx/data/core/Tokenizer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 |
8 | #include "mlx/data/core/Graph.h"
9 | #include "mlx/data/core/Trie.h"
10 |
11 | namespace mlx {
12 | namespace data {
13 | namespace core {
14 | // NOTE(review): free-function form of Tokenizer::tokenize; presumably builds the graph of all tokenizations of `input` over `trie` — confirm in Tokenizer.cpp.
15 | std::shared_ptr> tokenize(
16 | std::shared_ptr> trie,
17 | const std::string& input,
18 | bool ignore_unk = false); // NOTE(review): presumably skips input that matches no trie key instead of failing — confirm
19 | // Tokenizer bound to a trie; every tokenize* method is const and takes the text to split.
20 | class Tokenizer {
21 | public:
22 | Tokenizer(
23 | std::shared_ptr> trie,
24 | bool ignore_unk = false,
25 | const std::vector& trie_key_scores = {}); // optional scores, empty by default; NOTE(review): presumably one per trie key — confirm
26 | std::shared_ptr> tokenize(const std::string& input) const; // NOTE(review): presumably the graph of every tokenization of `input` — confirm
27 | std::vector tokenize_shortest(const std::string& input) const; // NOTE(review): presumably the tokenization with the fewest tokens — confirm
28 | std::vector tokenize_rand(const std::string& input) const; // NOTE(review): presumably one randomly chosen tokenization — confirm
29 |
30 | private:
31 | std::shared_ptr> trie_;
32 | bool ignoreUnk_;
33 | std::vector trieKeyScores_;
34 | bool trieKeyScoresPositive_; // NOTE(review): presumably caches whether every entry of trieKeyScores_ is positive — confirm
35 | };
36 | // Stateful walker over a graph; NOTE(review): next() presumably yields one tokenization per call and signals exhaustion somehow — confirm in Tokenizer.cpp.
37 | class TokenizerIterator {
38 | public:
39 | TokenizerIterator(std::shared_ptr> graph); // NOTE(review): single-argument and not `explicit`, so implicit conversion is allowed — confirm that is intended
40 | std::vector next();
41 |
42 | private:
43 | std::shared_ptr> g_; // NOTE(review): presumably the graph passed to the constructor
44 | std::vector edgeIndices_; // edge indices for path so far
45 | std::vector backEdgeIds_; // back edge ids for path so far
46 | int64_t currentNodeId_; // current node
47 | std::vector currentTokens_; // current tokenization
48 | std::unordered_set::const_iterator startNodeIterator_; // NOTE(review): presumably the position within the graph's set of start nodes — confirm
49 | bool new_start_();
50 | void forward_();
51 | };
52 |
53 | } // namespace core
54 | } // namespace data
55 | } // namespace mlx
56 |
--------------------------------------------------------------------------------
/mlx/data/core/Utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include "mlx/data/Array.h"
4 | #include "mlx/data/Sample.h"
5 |
6 | namespace mlx {
7 | namespace data {
8 | namespace core {
9 | // NOTE(review): presumably collapses repeated values of `src` along `dim`, pads with `pad`, and returns (result, lengths) — confirm in Utils.cpp.
10 | std::pair, std::shared_ptr> uniq(
11 | const std::shared_ptr src,
12 | const std::shared_ptr src_length,
13 | int dim,
14 | double pad);
15 | // NOTE(review): presumably drops entries equal to `value` along `dim`, pads with `pad`, and returns (result, lengths) — confirm in Utils.cpp.
16 | std::pair, std::shared_ptr> remove(
17 | const std::shared_ptr src,
18 | const std::shared_ptr src_length,
19 | int dim,
20 | double value,
21 | double pad);
22 | // NOTE(review): presumably replaces occurrences of `old` in `src` with `replacement`, bounded by `count` — confirm in Utils.cpp.
23 | std::shared_ptr replace(
24 | const std::shared_ptr& src,
25 | const std::shared_ptr& old,
26 | const std::shared_ptr& replacement,
27 | int count);
28 | // Builds a single Sample from `samples`; both maps are optional and default to empty.
29 | Sample merge_batch(
30 | const std::vector& samples,
31 | const std::unordered_map& pad_values = {}, // NOTE(review): presumably a per-key padding value — confirm
32 | const std::unordered_map& batch_dims = {}); // NOTE(review): presumably a per-key batching dimension — confirm
33 |
34 | } // namespace core
35 | } // namespace data
36 | } // namespace mlx
37 |
--------------------------------------------------------------------------------
/mlx/data/core/audio/Audio.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/data/core/audio/AudioPrivate.h"
4 |
5 | namespace mlx {
6 | namespace data {
7 | namespace core {
8 | namespace audio {
9 | // Loads audio from the file at `path`; forwards both arguments unchanged to load_sndfile().
10 | std::shared_ptr load(const std::string& path, AudioInfo* info) {
11 | return load_sndfile(path, info);
12 | }
13 | // Same as load(path, info) but takes `contents` instead of a path; forwards to the matching load_sndfile() overload.
14 | std::shared_ptr load(
15 | const std::shared_ptr& contents,
16 | AudioInfo* info) {
17 | return load_sndfile(contents, info);
18 | }
19 | // Returns the AudioInfo reported by info_sndfile() for the file at `path`.
20 | AudioInfo info(const std::string& path) {
21 | return info_sndfile(path);
22 | }
23 | // Returns the AudioInfo reported by info_sndfile() for `contents`.
24 | AudioInfo info(const std::shared_ptr& contents) {
25 | return info_sndfile(contents);
26 | }
27 | // Throws std::runtime_error unless `audio` has exactly two dimensions. `audio` is dereferenced unconditionally, so it must be non-null.
28 | void verify_audio(const std::shared_ptr& audio) {
29 | auto dimensions = audio->shape().size();
30 | if (dimensions != 2) {
31 | throw std::runtime_error(
32 | "verifyAudio: audio must be 2 dimension Array (SC)");
33 | }
34 | }
35 |
36 | } // namespace audio
37 | } // namespace core
38 | } // namespace data
39 | } // namespace mlx
40 |
--------------------------------------------------------------------------------
/mlx/data/core/audio/Audio.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/Array.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace core {
10 | namespace audio {
11 |
12 | /// Audio metadata
13 | struct AudioInfo {
14 | public:
15 | int64_t frames; // number of frames
16 | int sampleRate; // NOTE(review): presumably in Hz — confirm against the sndfile loader
17 | int channels; // number of channels
18 | };
19 | // Load audio from a file path or from `contents`. NOTE(review): `info` looks like an output parameter; whether it may be null is not visible here.
20 | std::shared_ptr load(const std::string& path, AudioInfo* info);
21 | std::shared_ptr load(
22 | const std::shared_ptr& contents,
23 | AudioInfo* info);
24 | // Return only the AudioInfo for a file path or for `contents`.
25 | AudioInfo info(const std::string& path);
26 | AudioInfo info(const std::shared_ptr& contents);
27 |
28 | /// Verify that the given Array is structured like audio: two dimensions
29 | /// (s, c). Throws std::runtime_error otherwise.
30 | void verify_audio(const std::shared_ptr& audio);
31 |
32 | /// Return the number of frames in the audio (`shape()[0]`). Requires that
33 | /// `verify_audio()` be called previously.
34 | inline const int64_t frames(const std::shared_ptr& audio) {
35 | return audio->shape()[0];
36 | }
37 |
38 | /// Return the channel count of the audio (`shape()[1]`). Requires that
39 | /// `verify_audio()` be called previously.
40 | inline const int64_t channels(const std::shared_ptr& audio) {
41 | return audio->shape()[1];
42 | }
43 |
44 | /// Resample mode -- these should match 1:1 with the libsamplerate enum, e.g.
45 | /// SRC_SINC_BEST_QUALITY; resample() casts the value straight to int for src_simple().
46 | enum class ResampleMode {
47 | best = 0, // SRC_SINC_BEST_QUALITY
48 | medium = 1, // SRC_SINC_MEDIUM_QUALITY
49 | fastest = 2, // SRC_SINC_FASTEST
50 | zeroOrderHold = 3, // SRC_ZERO_ORDER_HOLD
51 | linear = 4, // SRC_LINEAR
52 | };
53 | /// Resample `audio` from src_sample_rate to dst_sample_rate; NOTE(review): only the MLX_HAS_SAMPLERATE definition is fully visible — confirm the behaviour of the fallback.
54 | std::shared_ptr resample(
55 | const std::shared_ptr& audio,
56 | ResampleMode resample_mode,
57 | int src_sample_rate,
58 | int dst_sample_rate);
59 |
60 | } // namespace audio
61 | } // namespace core
62 | } // namespace data
63 | } // namespace mlx
64 |
--------------------------------------------------------------------------------
/mlx/data/core/audio/AudioPrivate.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/data/core/audio/Audio.h"
6 |
7 | namespace mlx {
8 | namespace data {
9 | namespace core {
10 | namespace audio {
11 | // Backends that Audio.cpp's load() overloads forward to; NOTE(review): presumably implemented with libsndfile — confirm.
12 | std::shared_ptr load_sndfile(const std::string& path, AudioInfo* info);
13 | std::shared_ptr load_sndfile(
14 | const std::shared_ptr& contents,
15 | AudioInfo* info);
16 | // Backends that Audio.cpp's info() overloads forward to.
17 | AudioInfo info_sndfile(const std::string& path);
18 | AudioInfo info_sndfile(const std::shared_ptr& contents);
19 |
20 | } // namespace audio
21 | } // namespace core
22 | } // namespace data
23 | } // namespace mlx
24 |
--------------------------------------------------------------------------------
/mlx/data/core/audio/AudioSampleRate.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include
4 | #include "mlx/data/core/audio/Audio.h"
5 |
6 | #ifdef MLX_HAS_SAMPLERATE
7 | #include
8 | #endif
9 |
10 | namespace mlx {
11 | namespace data {
12 | namespace core {
13 | namespace audio {
14 |
15 | #ifdef MLX_HAS_SAMPLERATE
16 |
17 | std::shared_ptr resample(
18 | const std::shared_ptr& audio,
19 | ResampleMode resample_mode,
20 | int src_sample_rate,
21 | int dst_sample_rate) {
22 | if ((dst_sample_rate <= 0) || (src_sample_rate == dst_sample_rate)) {
23 | return audio;
24 | }
25 |
26 | int64_t audio_channels = channels(audio);
27 | int64_t audio_length = frames(audio);
28 |
29 | double length_scale = static_cast(dst_sample_rate) /
30 | static_cast(src_sample_rate);
31 | int64_t new_audio_length =
32 | static_cast(std::floor(audio_length * length_scale));
33 | auto result = std::make_shared(
34 | ArrayType::Float, new_audio_length, audio_channels);
35 | SRC_DATA src_data;
36 | src_data.data_in = audio->data();
37 | src_data.input_frames = audio_length;
38 | src_data.data_out = result->data();
39 | src_data.output_frames = new_audio_length;
40 | src_data.src_ratio = length_scale;
41 | auto status = src_simple(&src_data, (int)resample_mode, audio_channels);
42 | if (status) {
43 | std::string msg("audio: libsamplerate failed with: ");
44 | msg += std::string(src_strerror(status));
45 | throw std::runtime_error(msg);
46 | }
47 |
48 | if (new_audio_length != src_data.output_frames_gen) {
49 | std::vector offset(2, 0);
50 | auto new_shape = result->shape();
51 | new_shape[0] = src_data.output_frames_gen;
52 | result = array::sub(result, offset, new_shape);
53 | }
54 |
55 | return result;
56 | }
57 |
58 | #else
59 |
60 | std::shared_ptr resample(
61 | const std::shared_ptr