├── .circleci └── config.yml ├── .clang-format ├── .gitattributes ├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGMENTS.md ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── comparative │ ├── README.md │ ├── caltech101 │ │ ├── .gitignore │ │ ├── README.md │ │ ├── mlx_data.py │ │ ├── pytorch.py │ │ ├── run_caltech.sh │ │ ├── tfds.py │ │ └── utils.py │ ├── librispeech │ │ ├── .gitignore │ │ ├── README.md │ │ ├── mlx_data.py │ │ ├── pytorch.py │ │ ├── run_librispeech.sh │ │ ├── tfds.py │ │ └── utils.py │ └── wikitext │ │ ├── .gitignore │ │ ├── README.md │ │ ├── mlx_data.py │ │ ├── pytorch.py │ │ ├── run_wikitext.sh │ │ ├── tfds.py │ │ └── utils.py └── utils.py ├── cmake ├── FindFFMPEG.cmake ├── FindFLAC.cmake ├── FindJPEGTURBO.cmake ├── FindOgg.cmake ├── FindSampleRate.cmake ├── FindSndFile.cmake ├── FindVorbis.cmake ├── bxzstr-v1.2.3.patch └── pybind11-v2.11.1.patch ├── docs ├── .gitignore ├── Makefile ├── README.md ├── requirements.txt └── src │ ├── _static │ ├── mlx_logo.png │ └── mlx_logo_dark.png │ ├── _templates │ └── data_core_modules.rst │ ├── buffers_streams_samples.rst │ ├── conf.py │ ├── hf_datasets_streams.rst │ ├── index.rst │ ├── install.rst │ ├── python │ ├── buffer.rst │ ├── common_datasets.rst │ ├── dataset.rst │ ├── features.rst │ ├── miscellaneous.rst │ ├── stream.rst │ └── tokenizing.rst │ └── quick_start.rst ├── mlx-data.pc.in ├── mlx └── data │ ├── Array.cpp │ ├── Array.h │ ├── Buffer.cpp │ ├── Buffer.h │ ├── Dataset.cpp │ ├── Dataset.h │ ├── Sample.cpp │ ├── Sample.h │ ├── Stream.cpp │ ├── Stream.h │ ├── buffer │ ├── Batch.cpp │ ├── Batch.h │ ├── Buffer.cpp │ ├── Buffer.h │ ├── DynamicBatch.cpp │ ├── DynamicBatch.h │ ├── FilesFromTAR.cpp │ ├── FilesFromTAR.h │ ├── FromStream.cpp │ ├── FromStream.h │ ├── FromVector.cpp │ ├── FromVector.h │ ├── Partition.cpp │ ├── Partition.h │ ├── Perm.cpp │ ├── Perm.h │ ├── Shuffle.cpp │ ├── Shuffle.h │ ├── Transform.cpp │ └── Transform.h │ ├── 
core │ ├── AWSFileFetcher.cpp │ ├── AWSFileFetcher.h │ ├── BPETokenizer.cpp │ ├── BPETokenizer.h │ ├── BatchShape.cpp │ ├── BatchShape.h │ ├── CSVReader.cpp │ ├── CSVReader.h │ ├── FileFetcher.cpp │ ├── FileFetcher.h │ ├── Graph.cpp │ ├── Graph.h │ ├── Levenshtein.cpp │ ├── Levenshtein.h │ ├── Numpy.cpp │ ├── Numpy.h │ ├── State.cpp │ ├── State.h │ ├── TARReader.cpp │ ├── TARReader.h │ ├── ThreadController.cpp │ ├── ThreadController.h │ ├── ThreadPool.cpp │ ├── ThreadPool.h │ ├── Tokenizer.cpp │ ├── Tokenizer.h │ ├── Trie.h │ ├── Utils.cpp │ ├── Utils.h │ ├── audio │ │ ├── Audio.cpp │ │ ├── Audio.h │ │ ├── AudioPrivate.h │ │ ├── AudioSampleRate.cpp │ │ └── AudioSndfile.cpp │ ├── image │ │ ├── Image.h │ │ ├── ImageIO.cpp │ │ ├── ImageJPEG.cpp │ │ ├── ImagePrivate.h │ │ ├── ImageSTBI.cpp │ │ └── ImageTransform.cpp │ ├── imemstream.h │ └── video │ │ ├── Video.cpp │ │ ├── Video.h │ │ ├── VideoFFMPEG.cpp │ │ └── VideoPrivate.h │ ├── op │ ├── FilterByShape.cpp │ ├── FilterByShape.h │ ├── FilterKey.cpp │ ├── FilterKey.h │ ├── ImageTransform.cpp │ ├── ImageTransform.h │ ├── KeyTransform.cpp │ ├── KeyTransform.h │ ├── LoadAudio.cpp │ ├── LoadAudio.h │ ├── LoadFile.cpp │ ├── LoadFile.h │ ├── LoadImage.cpp │ ├── LoadImage.h │ ├── LoadNumpy.cpp │ ├── LoadNumpy.h │ ├── LoadVideo.cpp │ ├── LoadVideo.h │ ├── Op.cpp │ ├── Op.h │ ├── Pad.cpp │ ├── Pad.h │ ├── ReadFromTAR.cpp │ ├── ReadFromTAR.h │ ├── RemoveValue.cpp │ ├── RemoveValue.h │ ├── RenameKey.cpp │ ├── RenameKey.h │ ├── Replace.cpp │ ├── Replace.h │ ├── SampleTransform.cpp │ ├── SampleTransform.h │ ├── SaveImage.cpp │ ├── SaveImage.h │ ├── Shape.cpp │ ├── Shape.h │ ├── Shard.cpp │ ├── Shard.h │ ├── Slice.cpp │ ├── Slice.h │ ├── Squeeze.cpp │ ├── Squeeze.h │ ├── Tokenize.cpp │ └── Tokenize.h │ └── stream │ ├── Batch.cpp │ ├── Batch.h │ ├── Buffered.cpp │ ├── Buffered.h │ ├── CSVReader.cpp │ ├── CSVReader.h │ ├── Compose.cpp │ ├── Compose.h │ ├── DynamicBatch.cpp │ ├── DynamicBatch.h │ ├── FromBuffer.cpp │ ├── FromBuffer.h │ 
├── LineReader.cpp │ ├── LineReader.h │ ├── OrderedPrefetch.cpp │ ├── OrderedPrefetch.h │ ├── Partition.cpp │ ├── Partition.h │ ├── Prefetch.cpp │ ├── Prefetch.h │ ├── Repeat.cpp │ ├── Repeat.h │ ├── Shuffle.cpp │ ├── Shuffle.h │ ├── SlidingWindow.cpp │ ├── SlidingWindow.h │ ├── Stream.cpp │ ├── Stream.h │ ├── Transform.cpp │ └── Transform.h ├── pyproject.toml ├── python ├── mlx │ └── data │ │ ├── __init__.py │ │ ├── core.py │ │ ├── datasets │ │ ├── __init__.py │ │ ├── cifar.py │ │ ├── common.py │ │ ├── image_folder.py │ │ ├── imagenet.py │ │ ├── librispeech.py │ │ ├── libritts_r.py │ │ ├── mnist.py │ │ ├── speechcommands.py │ │ └── wikitext.py │ │ ├── features │ │ ├── __init__.py │ │ └── audio.py │ │ └── tokenizer_helpers.py ├── src │ ├── CMakeLists.txt │ ├── wrap.cpp │ ├── wrap.h │ ├── wrap_buffer.cpp │ ├── wrap_core.cpp │ ├── wrap_dataset.h │ └── wrap_stream.cpp └── tests │ ├── test_bpe.py │ ├── test_buffer.py │ ├── test_general_ops.py │ └── test_replace.py ├── setup.py ├── super ├── CMakeLists.txt └── cmake │ ├── aws-1.11.557.patch │ ├── bzip2-1.0.8.patch │ ├── flac-1.5.0.patch │ └── xvidcore-1.3.7.patch └── tests └── CMakeLists.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | *.jpeg filter=lfs diff=lfs merge=lfs -text 2 | *.jpg filter=lfs diff=lfs merge=lfs -text 3 | *.png filter=lfs diff=lfs merge=lfs -text 4 | *.mov filter=lfs diff=lfs merge=lfs -text 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | build/ 3 | xcode/ 4 | dist/ 5 | __pycache__/ 6 | *.egg-info/ 7 | *.so 8 | *.c 9 | *.swp 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: 
https://github.com/pre-commit/mirrors-clang-format 3 | rev: v19.1.7 4 | hooks: 5 | - id: clang-format 6 | - repo: https://github.com/psf/black-pre-commit-mirror 7 | rev: 25.1.0 8 | hooks: 9 | - id: black 10 | - repo: https://github.com/pycqa/isort 11 | rev: 6.0.0 12 | hooks: 13 | - id: isort 14 | args: 15 | - --profile=black 16 | - repo: https://github.com/cheshirekow/cmake-format-precommit 17 | rev: v0.6.13 18 | hooks: 19 | - id: cmake-format 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX data 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before 11 | and after the change. Examples of benchmarks can be found in `benchmarks/`. 12 | 4. If you've changed APIs, update the documentation. 13 | 5. Every PR should have passing tests and at least one review. 14 | 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 15 | This should install hooks for running `black` and `clang-format` to ensure 16 | consistent style for C++ and python code. 17 | 18 | You can also run the formatters manually as follows: 19 | 20 | ``` 21 | clang-format -i file.cpp 22 | ``` 23 | 24 | ``` 25 | black file.py 26 | ``` 27 | 28 | or run `pre-commit run --all-files` to check all files in the repo. 29 | 30 | ## Issues 31 | 32 | We use GitHub issues to track public bugs. Please ensure your description is 33 | clear and has sufficient instructions to be able to reproduce the issue. 
34 | 35 | ## License 36 | 37 | By contributing to MLX data, you agree that your contributions will be licensed 38 | under the LICENSE file in the root directory of this source tree. 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmarks/comparative/README.md: -------------------------------------------------------------------------------- 1 | Comparative Benchmarks 2 | ====================== 3 | 4 | In this folder we compare `mlx.data` with PyTorch DataLoaders and `tf.data`. 
We 5 | try to keep the comparison as fair as possible but like any benchmark it can 6 | never replace real-world performance measurements on your particular use-case. 7 | 8 | The goal is to show that `mlx.data` is concise and fast. 9 | 10 | Running the benchmarks 11 | ---------------------- 12 | 13 | Each folder has separate instructions on how to download the data and run the benchmark. 14 | -------------------------------------------------------------------------------- /benchmarks/comparative/caltech101/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | -------------------------------------------------------------------------------- /benchmarks/comparative/caltech101/README.md: -------------------------------------------------------------------------------- 1 | Caltech 101 Benchmark 2 | ===================== 3 | 4 | This benchmark creates a simple image classification data loading pipeline from 5 | a set of directories containing images. Practically, what is implemented in 6 | `torchvision`'s `ImageFolder` dataset. 7 | 8 | For PyTorch we use the aforementioned `ImageFolder` dataset while for `tf.data` 9 | we build the pipeline according to [their image classification 10 | tutorial](https://www.tensorflow.org/tutorials/load_data/images). 11 | 12 | We apply a minimal set of transforms, namely we resize the original image to 13 | 256 pixels across the smallest dimension, center-crop to 224 by 224 pixels and 14 | transform the pixel values to floating point in [0, 1]. 
15 | 16 | Getting the data 17 | ---------------- 18 | 19 | You can download the data manually from 20 | [https://data.caltech.edu/records/mzrjq-6wc02](https://data.caltech.edu/records/mzrjq-6wc02), 21 | or use the provided bash script to download the data, extract it and run the 22 | benchmarks one after the other as follows: 23 | 24 | bash run_caltech.sh 25 | 26 | 27 | Running the benchmarks 28 | ---------------------- 29 | 30 | If you already have the data, you can run each benchmark by simply pointing it 31 | to the directory that holds the Caltech101 dataset. Just make sure that the archive 32 | containing the images is untarred. 33 | 34 | python mlx_data.py /path/to/Caltech101 35 | 36 | 37 | Dependencies 38 | ------------ 39 | 40 | To automatically download the data you need `wget`, `unzip` and `tar`. To run the 41 | benchmarks you need `tensorflow` and `PyTorch` with `torchvision` installed. 42 | -------------------------------------------------------------------------------- /benchmarks/comparative/caltech101/mlx_data.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
def files_and_classes(root: Path):
    """Collect the image files under ``root`` together with integer labels.

    Every ``*.jpg`` file found recursively under ``root`` is kept unless its
    path contains ``"BACKGROUND"``. The class of a file is the name of the
    directory that directly contains it; class names are sorted and numbered
    starting from 0.

    Args:
        root: Directory to search.

    Returns:
        A list with one ``dict(image=<path as bytes>, label=<int>)`` per file,
        in the order produced by ``Path.glob`` (not sorted).
    """
    # Local import so the function does not rely on the module's import block.
    import os

    files = [f for f in root.glob("**/*.jpg") if "BACKGROUND" not in str(f)]

    # Use the path API instead of splitting on "/" so the parent directory is
    # found regardless of the platform's path separator.
    class_names = sorted({f.parent.name for f in files})
    classes = {name: index for index, name in enumerate(class_names)}

    # os.fsencode round-trips any path returned by the OS, whereas
    # str.encode("ascii") raises UnicodeEncodeError on non-ASCII names.
    return [
        dict(image=os.fsencode(str(f)), label=classes[f.parent.name]) for f in files
    ]
class Caltech101(ImageFolder):
    """``ImageFolder`` over a Caltech 101 directory with the benchmark transforms.

    Images are converted with ``ToImage``, resized with ``Resize(256)``,
    center-cropped with ``CenterCrop(224)`` and converted to ``float32`` with
    ``scale=True``. Only files accepted by :meth:`is_valid_file` are listed.
    """

    def __init__(self, root: str):
        # Build the transform pipeline first so the super() call stays short.
        pipeline = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToDtype(torch.float32, scale=True),
            ]
        )
        super().__init__(
            str(root),
            transform=pipeline,
            is_valid_file=self.is_valid_file,
        )

    def is_valid_file(self, filepath: str):
        """Return True for paths ending in ``jpg`` that do not contain ``BACKGROUND``."""
        if "BACKGROUND" in filepath:
            return False
        return filepath.endswith("jpg")
def process_sample(sample):
    """Load one JPEG and apply the benchmark's image transforms.

    The image is decoded with 3 channels, resized so that its smallest side is
    256 pixels (aspect ratio preserved), cropped/padded centrally to 224x224
    and divided by 255 to obtain floating point values.

    Args:
        sample: Mapping with a ``"file"`` entry (path of the JPEG to read) and
            a ``"label"`` entry, which is passed through unchanged.

    Returns:
        ``dict(image=<processed image tensor>, label=sample["label"])``.
    """
    file_path = sample["file"]
    data = tf.io.read_file(file_path)
    img = tf.io.decode_jpeg(data, channels=3)

    height, width = to_float32([tf.shape(img)[0], tf.shape(img)[1]])
    min_side = to_float32(tf.minimum(height, width))

    # target / current: multiplying both sides by this factor maps the smallest
    # side to 256. (The inverse ratio, min_side / 256, would instead enlarge
    # images that are already bigger than 256 pixels.)
    scale_factor = tf.constant(256, dtype=tf.float32) / min_side

    # Round before the integer cast so floating point error cannot truncate
    # the smallest side to 255.
    new_size = to_int32(
        [tf.round(scale_factor * height), tf.round(scale_factor * width)]
    )
    img = tf.image.resize(img, new_size)
    img = tf.image.resize_with_crop_or_pad(img, 224, 224)
    img = img / tf.constant(255, dtype=tf.float32)

    return dict(image=img, label=sample["label"])
class Benchmark:
    """Times named runs of a callable and prints a summary table.

    Several calls to :meth:`log_run` with the same ``run_name`` are averaged
    in :meth:`report`.
    """

    def __init__(self, name):
        """Create an empty benchmark titled ``name``."""
        self.name = name
        # run name -> elapsed seconds, one entry per log_run call
        self.runtimes = defaultdict(list)
        # run name -> value returned by the most recent run with that name
        self.n_samples = defaultdict(int)

    def log_run(self, run_name, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and record how long it took.

        Args:
            run_name: Key under which the timing is stored.
            fn: Callable to time; its return value is stored as the number of
                samples processed by the run.
            *args, **kwargs: Forwarded to ``fn``.
        """
        # perf_counter is monotonic and high resolution, unlike time.time()
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        self.runtimes[run_name].append(elapsed)
        self.n_samples[run_name] = n_samples

    def report(self):
        """Print the average time, throughput and run count of every run name."""
        print(f"Benchmark {self.name}")
        print()

        # Nothing was logged: stop after the header instead of failing on an
        # empty table.
        if not self.runtimes:
            return

        table = []
        for k, times in self.runtimes.items():
            n_runs = len(times)
            avg_time = sum(times) / n_runs
            # Guard against a zero average from runs faster than the clock
            # resolution.
            avg_throughput = (
                self.n_samples[k] / avg_time if avg_time > 0 else float("inf")
            )
            table.append(
                [
                    k,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        column_widths = [max(len(row[i]) for row in table) for i in range(len(table[0]))]
        for row in table:
            for i, (v, w) in enumerate(zip(row, column_widths)):
                if i == 0:
                    # First column (run name) is left-aligned.
                    print(v, " " * (w - len(v)), " " * 4, end="")
                else:
                    # Remaining columns are right-aligned.
                    print(" " * (w - len(v)), v, " " * 4, end="")
            print()
We transform the audio to Mel scaled 7 | spectrograms and the transcripts to tokens from a SentencePiece tokenizer. 8 | 9 | For PyTorch we use the torchaudio LIBRISPEECH dataset and the `MelSpectrogram` 10 | transform with soundfile to load the audio. For tensorflow we use 11 | `tensorflow_text` for the tokenizer and `tensorflow_io` to read the audio. 12 | 13 | A more realistic ASR pipeline would load the audio info first in order to 14 | subsequently group the files by length. This can be achieved efficiently in `mlx.data` 15 | using buffered streams. 16 | 17 | Getting the data 18 | ---------------- 19 | 20 | You can download the data manually from 21 | [https://www.openslr.org/12](https://www.openslr.org/12) or use the provided 22 | bash script to download the data, the tokenizer and run the benchmarks. The 23 | tokenizer model we use is the Llama SPM model provided by Huggingface. 24 | 25 | bash run_librispeech.sh 26 | 27 | 28 | Running the benchmarks 29 | ---------------------- 30 | 31 | If you already have the data, you can run each benchmark by simply pointing it 32 | to the directory that holds the LibriSpeech dataset. Just make sure the archive is 33 | extracted. 34 | 35 | OMP_NUM_THREADS=1 python mlx_data.py \ 36 | --tokenizer_file /path/to/tokenizer.model \ 37 | /path/to/librispeech/LibriSpeech/dev-clean 38 | 39 | You should run the PyTorch and `mlx.data` benchmarks with `OMP_NUM_THREADS=1` 40 | to avoid thread contention from computing the FFT and other operations on CPU. 41 | 42 | Dependencies 43 | ------------ 44 | 45 | To automatically download the data you need `wget` and `tar`. To run the 46 | benchmarks you need the following dependencies: 47 | 48 | - `tensorflow` 49 | - `tensorflow_io` 50 | - `tensorflow_text` 51 | - `sentencepiece` 52 | - `torch` 53 | - `torchaudio` 54 | - `intel-numpy` for fast FFT 55 | 56 | Since the featurization in `mlx.data` happens in `numpy`, we recommend installing 57 | `intel-numpy` for a fast FFT implementation. 
def to_audio_and_transcript(sample):
    """Turn one transcript line into an audio path and a lowercase transcript.

    ``sample["line"]`` is converted to ``bytes`` and split at the first space:
    the left part is the utterance id, the right part the transcript. The audio
    path is built by splitting the id on ``-``, keeping every field except the
    last as a directory and appending ``<id>.flac`` as the file name.

    Returns:
        ``{"audio": <relative path as bytes>, "transcript": <lowercased bytes>}``.
    """
    raw_line = bytes(sample["line"])
    utterance_id, text = raw_line.split(b" ", 1)

    # Every dash-separated field but the last names a directory; the file
    # itself is named after the complete utterance id.
    *directories, _ = utterance_id.split(b"-")
    audio_path = b"/".join(directories + [utterance_id + b".flac"])

    return {"audio": audio_path, "transcript": text.lower()}
#!/bin/bash
#
# Download LibriSpeech dev-clean and the Llama SentencePiece tokenizer into
# ./data (only when ./data does not exist yet), then run the PyTorch, tf.data
# and mlx.data benchmarks one after the other.

# Get the data
if [ ! -d data ]; then
    mkdir data
    # Stop rather than download into the wrong directory if data/ cannot be
    # entered.
    pushd data || exit 1
    wget "https://www.openslr.org/resources/12/dev-clean.tar.gz" -O dev-clean.tar.gz
    tar -zxf dev-clean.tar.gz && rm dev-clean.tar.gz
    wget "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model" -O tokenizer.model
    popd
fi

# Run the benchmarks
# The PyTorch and mlx.data runs use OMP_NUM_THREADS=1 to avoid thread
# contention from the FFT and other CPU operations, as the README in this
# folder recommends.
OMP_NUM_THREADS=1 python pytorch.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean 2>/dev/null
echo "============="
python tfds.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean 2>/dev/null
echo "============="
OMP_NUM_THREADS=1 python mlx_data.py --tokenizer_file data/tokenizer.model data/LibriSpeech/dev-clean
class Benchmark:
    """Times named runs of a callable and prints a summary table.

    Several calls to :meth:`log_run` with the same ``run_name`` are averaged
    in :meth:`report`.
    """

    def __init__(self, name):
        """Create an empty benchmark titled ``name``."""
        self.name = name
        # run name -> elapsed seconds, one entry per log_run call
        self.runtimes = defaultdict(list)
        # run name -> value returned by the most recent run with that name
        self.n_samples = defaultdict(int)

    def log_run(self, run_name, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and record how long it took.

        Args:
            run_name: Key under which the timing is stored.
            fn: Callable to time; its return value is stored as the number of
                samples processed by the run.
            *args, **kwargs: Forwarded to ``fn``.
        """
        # perf_counter is monotonic and high resolution, unlike time.time()
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        n_samples = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        self.runtimes[run_name].append(elapsed)
        self.n_samples[run_name] = n_samples

    def report(self):
        """Print the average time, throughput and run count of every run name."""
        print(f"Benchmark {self.name}")
        print()

        # Nothing was logged: stop after the header instead of failing on an
        # empty table.
        if not self.runtimes:
            return

        table = []
        for k, times in self.runtimes.items():
            n_runs = len(times)
            avg_time = sum(times) / n_runs
            # Guard against a zero average from runs faster than the clock
            # resolution.
            avg_throughput = (
                self.n_samples[k] / avg_time if avg_time > 0 else float("inf")
            )
            table.append(
                [
                    k,
                    f"{avg_time:.3f} s",
                    f"{avg_throughput:.3f} samples/s",
                    f"{n_runs} runs",
                ]
            )

        column_widths = [max(len(row[i]) for row in table) for i in range(len(table[0]))]
        for row in table:
            for i, (v, w) in enumerate(zip(row, column_widths)):
                if i == 0:
                    # First column (run name) is left-aligned.
                    print(v, " " * (w - len(v)), " " * 4, end="")
                else:
                    # Remaining columns are right-aligned.
                    print(" " * (w - len(v)), v, " " * 4, end="")
            print()
#!/bin/bash
#
# Download the raw WikiText-103 archive and the Llama SentencePiece tokenizer
# into ./data (only when ./data does not exist yet), then run the PyTorch,
# tf.data and mlx.data benchmarks one after the other.

# Get the data
if [ ! -d data ]; then
    mkdir data
    # Stop rather than download into the wrong directory if data/ cannot be
    # entered.
    pushd data || exit 1
    wget "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" -O wikitext-103-raw-v1.zip
    unzip wikitext-103-raw-v1.zip && rm wikitext-103-raw-v1.zip
    wget "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model" -O tokenizer.model
    popd
fi

# Run the benchmarks
python pytorch.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw
echo "============="
python tfds.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw 2>/dev/null
echo "============="
python mlx_data.py --tokenizer_file data/tokenizer.model data/wikitext-103-raw
table: 43 | for i, (v, w) in enumerate(zip(row, column_widths)): 44 | if i == 0: 45 | print(v, " " * (w - len(v)), " " * 4, end="") 46 | else: 47 | print(" " * (w - len(v)), v, " " * 4, end="") 48 | print() 49 | -------------------------------------------------------------------------------- /benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import time 4 | from collections import defaultdict 5 | 6 | 7 | class Benchmark: 8 | def __init__(self, name): 9 | self.name = name 10 | self.runtimes = defaultdict(list) 11 | self.n_samples = defaultdict(int) 12 | self.run_count = 0 13 | 14 | def log_run(self, run_name, fn, *args, **kwargs): 15 | print(f"Starting run {self.run_count}.", flush=True) 16 | start = time.time() 17 | n_samples = fn(*args, **kwargs) 18 | end = time.time() 19 | self.runtimes[run_name].append(end - start) 20 | self.n_samples[run_name] = n_samples 21 | print("-------------") 22 | self.run_count += 1 23 | 24 | def report(self): 25 | print(f"Benchmark {self.name}") 26 | print() 27 | klengths = max(map(len, self.runtimes.keys())) 28 | table = [] 29 | for k in self.runtimes.keys(): 30 | n_runs = len(self.runtimes[k]) 31 | avg_time = sum(self.runtimes[k]) / n_runs 32 | avg_throughput = self.n_samples[k] / avg_time 33 | table.append( 34 | [ 35 | k, 36 | f"{avg_time:.3f} s", 37 | f"{avg_throughput:.3f} samples/s", 38 | f"{n_runs} runs", 39 | ] 40 | ) 41 | 42 | column_widths = [ 43 | max(len(table[j][i]) for j in range(len(table))) 44 | for i in range(len(table[0])) 45 | ] 46 | for row in table: 47 | for i, (v, w) in enumerate(zip(row, column_widths)): 48 | if i == 0: 49 | print(v, " " * (w - len(v)), " " * 4, end="") 50 | else: 51 | print(" " * (w - len(v)), v, " " * 4, end="") 52 | print() 53 | -------------------------------------------------------------------------------- /cmake/FindFLAC.cmake: 
-------------------------------------------------------------------------------- 1 | # * Find FLAC Find the native FLAC includes and libraries 2 | # 3 | # Sets the following imported targets if FLAC is found: FLAC::FLAC 4 | # 5 | # Sets the following legacy CMake variables: FLAC_INCLUDE_DIRS - where to find 6 | # FLAC headers. FLAC_LIBRARIES - List of libraries when using libFLAC. 7 | # FLAC_FOUND - True if libFLAC found. FLAC_DEFINITIONS - FLAC compile 8 | # definitions 9 | 10 | if(FLAC_INCLUDE_DIR) 11 | # Already in cache, be silent 12 | set(FLAC_FIND_QUIETLY TRUE) 13 | endif() 14 | 15 | find_package(Ogg QUIET) 16 | 17 | find_package(PkgConfig QUIET) 18 | pkg_check_modules(PC_FLAC QUIET flac) 19 | 20 | set(FLAC_VERSION ${PC_FLAC_VERSION}) 21 | 22 | find_path(FLAC_INCLUDE_DIR FLAC/stream_decoder.h 23 | HINTS ${PC_FLAC_INCLUDEDIR} ${PC_FLAC_INCLUDE_DIRS} ${FLAC_ROOT}) 24 | 25 | # MSVC built libraries can name them *_static, which is good as it distinguishes 26 | # import libraries from static libraries with the same extension. 27 | find_library( 28 | FLAC_LIBRARY 29 | NAMES FLAC libFLAC libFLAC_dynamic libFLAC_static 30 | HINTS ${PC_FLAC_LIBDIR} ${PC_FLAC_LIBRARY_DIRS} ${FLAC_ROOT}) 31 | 32 | # Handle the QUIETLY and REQUIRED arguments and set FLAC_FOUND to TRUE if all 33 | # listed variables are TRUE. 
34 | include(FindPackageHandleStandardArgs) 35 | find_package_handle_standard_args( 36 | FLAC 37 | REQUIRED_VARS FLAC_LIBRARY FLAC_INCLUDE_DIR OGG_FOUND 38 | VERSION_VAR FLAC_VERSION) 39 | 40 | if(FLAC_FOUND) 41 | set(FLAC_INCLUDE_DIRS ${FLAC_INCLUDE_DIR}) 42 | set(FLAC_LIBRARIES ${FLAC_LIBRARY} ${OGG_LIBRARIES}) 43 | if(WIN32) 44 | set(FLAC_LIBRARIES ${FLAC_LIBRARIES} wsock32) 45 | get_filename_component(FLAC_LIBRARY_FILENAME ${FLAC_LIBRARY} NAME_WE) 46 | if(FLAC_LIBRARY_FILENAME MATCHES "libFLAC_static") 47 | set(FLAC_DEFINITIONS -DFLAC__NO_DLL) 48 | endif() 49 | endif() 50 | if(NOT TARGET FLAC::FLAC) 51 | add_library(FLAC::FLAC UNKNOWN IMPORTED) 52 | set_target_properties( 53 | FLAC::FLAC 54 | PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${FLAC_INCLUDE_DIR}" 55 | IMPORTED_LOCATION "${FLAC_LIBRARY}" 56 | INTERFACE_LINK_LIBRARIES Ogg::ogg 57 | $<$:wsock32> INTERFACE_COMPILE_DEFINITIONS 58 | ${FLAC_DEFINITIONS}) 59 | endif() 60 | endif() 61 | 62 | mark_as_advanced(FLAC_INCLUDE_DIR FLAC_LIBRARY) 63 | -------------------------------------------------------------------------------- /cmake/FindOgg.cmake: -------------------------------------------------------------------------------- 1 | # * Find ogg Find the native ogg includes and libraries 2 | # 3 | # OGG_INCLUDE_DIRS - where to find ogg.h, etc. OGG_LIBRARIES - List of 4 | # libraries when using ogg. OGG_FOUND - True if ogg found. 5 | 6 | find_package(Ogg CONFIG QUIET) 7 | 8 | if(NOT TARGET Ogg::ogg) 9 | if(OGG_INCLUDE_DIR) 10 | # Already in cache, be silent 11 | set(OGG_FIND_QUIETLY TRUE) 12 | endif() 13 | 14 | find_package(PkgConfig QUIET) 15 | pkg_check_modules(PC_OGG QUIET ogg) 16 | 17 | set(OGG_VERSION ${PC_OGG_VERSION}) 18 | 19 | find_path(OGG_INCLUDE_DIR ogg/ogg.h HINTS ${PC_OGG_INCLUDEDIR} 20 | ${PC_OGG_INCLUDE_DIRS} ${OGG_ROOT}) 21 | # MSVC built ogg may be named ogg_static. The provided project files name the 22 | # library with the lib prefix. 
23 | find_library( 24 | OGG_LIBRARY 25 | NAMES ogg ogg_static libogg libogg_static 26 | HINTS ${PC_OGG_LIBDIR} ${PC_OGG_LIBRARY_DIRS} ${OGG_ROOT}) 27 | # Handle the QUIETLY and REQUIRED arguments and set OGG_FOUND to TRUE if all 28 | # listed variables are TRUE. 29 | include(FindPackageHandleStandardArgs) 30 | find_package_handle_standard_args( 31 | Ogg 32 | REQUIRED_VARS OGG_LIBRARY OGG_INCLUDE_DIR 33 | VERSION_VAR OGG_VERSION) 34 | 35 | if(OGG_FOUND) 36 | set(OGG_LIBRARIES ${OGG_LIBRARY}) 37 | set(OGG_INCLUDE_DIRS ${OGG_INCLUDE_DIR}) 38 | 39 | if(NOT TARGET Ogg::ogg) 40 | add_library(Ogg::ogg UNKNOWN IMPORTED) 41 | set_target_properties( 42 | Ogg::ogg PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${OGG_INCLUDE_DIRS}" 43 | IMPORTED_LOCATION "${OGG_LIBRARIES}") 44 | endif() 45 | endif() 46 | 47 | mark_as_advanced(OGG_INCLUDE_DIR OGG_LIBRARY) 48 | endif() 49 | -------------------------------------------------------------------------------- /cmake/FindSampleRate.cmake: -------------------------------------------------------------------------------- 1 | # Try to find libsamplerate 2 | # 3 | # Provides the cmake config target - SampleRate::samplerate 4 | # 5 | # Inputs: SampleRate_INC_DIR: include directory for samplerate headers 6 | # SampleRate_LIB_DIR: directory containing samplerate libraries 7 | # SampleRate_ROOT_DIR: directory containing samplerate installation 8 | # 9 | # Defines: SampleRate_FOUND - system has libsamplerate SampleRate_INCLUDE_DIRS - 10 | # the libsamplerate include directory SampleRate_LIBRARIES - Link these to use 11 | # libsamplerate 12 | # 13 | 14 | find_package(SampleRate CONFIG) 15 | 16 | if(NOT TARGET SampleRate::samplerate) 17 | find_path( 18 | SampleRate_INCLUDE_DIR samplerate.h 19 | PATHS ${SampleRate_INC_DIR} ${SampleRate_ROOT_DIR}/include 20 | PATH_SUFFIXES include) 21 | 22 | find_library( 23 | SampleRate_LIBRARY samplerate 24 | PATHS ${SampleRate_LIB_DIR} ${SampleRate_ROOT_DIR} 25 | PATH_SUFFIXES lib 26 | HINTS SAMPLERATE) 27 | 28 | 
set(SampleRate_INCLUDE_DIRS ${SampleRate_INCLUDE_DIR}) 29 | set(SampleRate_LIBRARIES ${SampleRate_LIBRARY}) 30 | 31 | mark_as_advanced(SampleRate_INCLUDE_DIRS SampleRate_LIBRARIES) 32 | include(FindPackageHandleStandardArgs) 33 | find_package_handle_standard_args( 34 | SampleRate DEFAULT_MSG SampleRate_INCLUDE_DIRS SampleRate_LIBRARIES) 35 | 36 | if(SampleRate_FOUND) 37 | add_library(SampleRate::samplerate UNKNOWN IMPORTED) 38 | set_target_properties( 39 | SampleRate::samplerate 40 | PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SampleRate_INCLUDE_DIRS}" 41 | IMPORTED_LOCATION "${SampleRate_LIBRARIES}" 42 | INTERFACE_LINK_LIBRARIES "${SAMPLERATE_DEP_LIBRARIES}") 43 | message( 44 | STATUS 45 | "Found libsamplerate: (lib: ${SampleRate_LIBRARIES} include: ${SampleRate_INCLUDE_DIRS})" 46 | ) 47 | else() 48 | message(STATUS "libsamplerate not found.") 49 | endif() 50 | endif() # NOT TARGET SampleRate::samplerate 51 | -------------------------------------------------------------------------------- /cmake/pybind11-v2.11.1.patch: -------------------------------------------------------------------------------- 1 | diff --git a/CMakeLists.txt b/CMakeLists.txt 2 | index 87ec1034..eaef1a4c 100644 3 | --- a/CMakeLists.txt 4 | +++ b/CMakeLists.txt 5 | @@ -16,6 +16,10 @@ else() 6 | cmake_policy(VERSION 3.26) 7 | endif() 8 | 9 | +if(POLICY CMP0148) 10 | + cmake_policy(SET CMP0148 NEW) 11 | +endif() 12 | + 13 | # Avoid infinite recursion if tests include this as a subdirectory 14 | if(DEFINED PYBIND11_MASTER_PROJECT) 15 | return() 16 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | _autosummary 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # 
You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = src 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Build the Docs 2 | 3 | ### Setup (do once) 4 | 5 | Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html) 6 | for example with `conda`: 7 | 8 | ``` 9 | conda install sphinx 10 | pip install sphinx-book-theme 11 | ``` 12 | 13 | ### Build 14 | 15 | Build the docs from `mlx/docs/` 16 | 17 | ``` 18 | make html 19 | ``` 20 | 21 | View the docs by running a server in `mlx/docs/build/html/`: 22 | 23 | ``` 24 | python -m http.server 25 | ``` 26 | 27 | and point your browser to `http://localhost:<port>` (port 8000 by default). 
28 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | breathe 3 | sphinx-book-theme 4 | mlx-data 5 | -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-data/8c6d8b3efa6458812c8c3a70a87806b2bc1c056c/docs/src/_static/mlx_logo.png -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-data/8c6d8b3efa6458812c8c3a70a87806b2bc1c056c/docs/src/_static/mlx_logo_dark.png -------------------------------------------------------------------------------- /docs/src/_templates/data_core_modules.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | {% for item in methods %} 15 | {%- if item not in inherited_members %} 16 | ~{{ name }}.{{ item }} 17 | {%- endif %} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | -------------------------------------------------------------------------------- /docs/src/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | # -*- coding: utf-8 -*- 4 | 5 | import os 6 | import subprocess 7 | 8 | # -- Project information ----------------------------------------------------- 9 | 10 | project = "MLX Data" 11 | copyright = "2023, MLX Contributors" 12 | author = "MLX Contributors" 13 | version = "0.0.2" 14 | release = "0.0.2" 15 | 16 | # -- General configuration --------------------------------------------------- 17 | 18 | extensions = [ 19 | "sphinx.ext.autodoc", 20 | "sphinx.ext.autosummary", 21 | "sphinx.ext.intersphinx", 22 | "sphinx.ext.napoleon", 23 | ] 24 | 25 | autosummary_generate = True 26 | 27 | intersphinx_mapping = { 28 | "python": ("https://docs.python.org/3", None), 29 | "numpy": ("https://numpy.org/doc/stable/", None), 30 | } 31 | 32 | templates_path = ["_templates"] 33 | html_static_path = ["_static"] 34 | source_suffix = ".rst" 35 | master_doc = "index" 36 | highlight_language = "python" 37 | pygments_style = "sphinx" 38 | 39 | # -- Options for HTML output ------------------------------------------------- 40 | 41 | html_theme = "sphinx_book_theme" 42 | 43 | html_theme_options = { 44 | "show_toc_level": 2, 45 | "repository_url": "https://github.com/ml-explore/mlx-data", 46 | "use_repository_button": True, 47 | "navigation_with_keys": False, 48 | "logo": { 49 | "image_light": "_static/mlx_logo.png", 50 | "image_dark": "_static/mlx_logo_dark.png", 51 | }, 52 | } 53 | 54 | # -- Options for HTMLHelp output --------------------------------------------- 55 | 56 | htmlhelp_basename = "mlx_data_doc" 57 | -------------------------------------------------------------------------------- /docs/src/index.rst: -------------------------------------------------------------------------------- 1 | MLX Data 2 | ======== 3 | 4 | MLX Data is a framework agnostic data loading library brought to you by Apple 5 | machine learning research. 6 | 7 | MLX Data can be used to load data for machine learning training or on its own 8 | for data pre-processing. 
You can use it with PyTorch, Jax or `MLX 9 | <https://ml-explore.github.io/mlx/>`_. 10 | 11 | The goal of this library is to allow users to leverage multiple threads for 12 | data processing pipelines without the inflexibility of dealing with multiple 13 | processes or having to write in a symbolic language. 14 | 15 | .. note:: 16 | In MLX Data pipelines you can use Python to process data, implement logic or cause side effects! 17 | 18 | 19 | .. toctree:: 20 | :caption: Install 21 | :maxdepth: 1 22 | 23 | install 24 | 25 | .. toctree:: 26 | :caption: Usage 27 | :maxdepth: 1 28 | 29 | quick_start 30 | buffers_streams_samples 31 | hf_datasets_streams 32 | 33 | .. toctree:: 34 | :caption: Python API Reference 35 | :maxdepth: 1 36 | 37 | python/dataset 38 | python/buffer 39 | python/stream 40 | python/common_datasets 41 | python/tokenizing 42 | python/features 43 | python/miscellaneous 44 | -------------------------------------------------------------------------------- /docs/src/python/buffer.rst: -------------------------------------------------------------------------------- 1 | .. _buffer: 2 | 3 | Buffer 4 | ====== 5 | 6 | .. currentmodule:: mlx.data 7 | 8 | As also mentioned in :ref:`Buffers, Streams and Samples ` 9 | a :class:`Buffer` is an indexable container of 10 | samples. Using a buffer in python should feel very similar to accessing a list 11 | of samples. 12 | 13 | .. code-block:: python 14 | 15 | import mlx.data as dx 16 | 17 | numbers = dx.buffer_from_vector([{"x": i} for i in range(10)]) 18 | evens = numbers.key_transform("x", lambda x: 2*x) 19 | 20 | print(evens) 21 | # prints Buffer(size=10, keys={'x'}) 22 | 23 | print(evens[3]) 24 | # prints {'x': array(6)} 25 | 26 | print(len(evens)) 27 | # prints 10 28 | 29 | Factory methods 30 | --------------- 31 | 32 | We provide the following factory methods to create a buffer. 33 | 34 | .. 
autosummary:: 35 | :toctree: _autosummary 36 | 37 | buffer_from_vector 38 | files_from_tar 39 | 40 | Buffer specific API 41 | ------------------- 42 | 43 | The random access characteristics of a ``Buffer`` allow us to define some 44 | transformations that cannot be implemented or do not make sense for a 45 | :class:`Stream`. 46 | 47 | .. autosummary:: 48 | :toctree: _autosummary 49 | 50 | Buffer.ordered_prefetch 51 | Buffer.partition 52 | Buffer.perm 53 | Buffer.shuffle 54 | Buffer.to_stream 55 | -------------------------------------------------------------------------------- /docs/src/python/features.rst: -------------------------------------------------------------------------------- 1 | Feature extraction 2 | ================== 3 | 4 | This submodule provides some feature extraction utilities that can be used as 5 | ``key_transform`` functions in MLX data pipelines. Even though a C++ 6 | implementation would allow for completely circumventing the GIL and better 7 | utilization of multiple threads, we find that an efficient numpy implementation 8 | can often be fast enough while providing significantly more flexibility. 9 | 10 | .. currentmodule:: mlx.data.features 11 | 12 | Audio Features 13 | -------------- 14 | 15 | .. autosummary:: 16 | :toctree: _autosummary 17 | 18 | WindowType 19 | FrequencyScale 20 | mfsc 21 | -------------------------------------------------------------------------------- /docs/src/python/miscellaneous.rst: -------------------------------------------------------------------------------- 1 | Miscellaneous 2 | ============== 3 | 4 | .. currentmodule:: mlx.data 5 | 6 | FileFetcher 7 | ----------- 8 | 9 | Several functions in MLX data can make use of a :class:`FileFetcher` object to 10 | fetch files from a remote location. See the :ref:`installation instructions ` to build MLX data with AWS support which adds the 11 | :class:`core.AWSFileFetcher` described below. 12 | 13 | .. 
autosummary:: 14 | :toctree: _autosummary 15 | :recursive: 16 | 17 | core.AWSFileFetcher.__init__ 18 | core.AWSFileFetcher.fetch 19 | core.AWSFileFetcher.prefetch 20 | 21 | A :class:`FileFetcher` can also be used standalone in your scripts to 22 | efficiently fetch remote content in background threads. 23 | 24 | .. code-block:: python 25 | 26 | from pathlib import Path 27 | from mlx.data.core import AWSFileFetcher 28 | 29 | LOCAL_CACHE = Path("/path/to/local/cache") 30 | 31 | ff = AWSFileFetcher( 32 | "my-cool-bucket", 33 | endpoint="https://my.endpoint.com/", 34 | local_prefix=LOCAL_CACHE, 35 | num_kept_files=100, 36 | ) 37 | 38 | # When fetch returns my/remote/path/foo.npy will be in LOCAL_CACHE 39 | ff.fetch("my/remote/path/foo.npy") 40 | assert (LOCAL_CACHE / "my/remote/path/foo.npy").is_file() 41 | 42 | # We can prefetch in the background 43 | ff.prefetch(["foo_1.npy", "foo_2.npy"]) 44 | ff.fetch("foo_1.npy") 45 | # process foo_1 while foo_2 downloads in the background 46 | -------------------------------------------------------------------------------- /docs/src/python/stream.rst: -------------------------------------------------------------------------------- 1 | .. _stream: 2 | 3 | Stream 4 | ====== 5 | 6 | .. currentmodule:: mlx.data 7 | 8 | Using a :class:`Stream` in python should feel like accessing an iterator of 9 | samples. Only the next sample can be fetched and the iteration may be restarted 10 | depending on the underlying source of the data (some truly online sources are 11 | not resettable). 12 | 13 | .. 
code-block:: python 14 | 15 | import mlx.data as dx 16 | 17 | # The samples are never all instantiated 18 | numbers = dx.stream_python_iterable(lambda: ({"x": i} for i in range(10**10))) 19 | 20 | # Filtering is done with transforms returning an empty sample 21 | evens = numbers.sample_transform(lambda s: s if s["x"] % 2 == 0 else dict()) 22 | 23 | print(next(numbers)) 24 | # prints {'x': array(0)} 25 | print(next(numbers)) 26 | # prints {'x': array(1)} 27 | 28 | # Streams are pointers to the streams so evens is using numbers under the 29 | # hood. Since numbers was advanced now evens is advanced as well. 30 | print(next(evens)) 31 | # prints {'x': array(2)} 32 | print(next(evens)) 33 | # prints {'x': array(4)} 34 | print(next(numbers)) 35 | # prints {'x': array(5)} 36 | 37 | # Streams can be reset. 38 | evens.reset() 39 | print(next(evens)) 40 | print(next(evens)) 41 | print(next(numbers)) 42 | # prints {'x': array(0)} 43 | # {'x': array(2)} 44 | # {'x': array(3)} 45 | 46 | 47 | Factory methods 48 | --------------- 49 | 50 | We provide the following factory methods to create a stream. When used from 51 | python the most interesting one is probably :func:`stream_python_iterable`. Of 52 | course another good strategy is to start from a :class:`Buffer` that you then 53 | cast to a stream using :meth:`Buffer.to_stream`. 54 | 55 | .. autosummary:: 56 | :toctree: _autosummary 57 | 58 | stream_csv_reader 59 | stream_csv_reader_from_string 60 | stream_line_reader 61 | stream_python_iterable 62 | 63 | Stream specific API 64 | ------------------- 65 | 66 | :class:`Stream` has a more powerful API than :class:`Buffer`. It does not allow 67 | for random access, however, it allows for stream composing and prefetching. 68 | Stream composing is when a sample becomes the beginning of a new stream that 69 | can have arbitrary length. 
70 | 71 | Streams also allow for filtering using the provided functions or a 72 | :meth:`Stream.sample_transform` that returns an empty dictionary (an empty 73 | Sample). 74 | 75 | .. autosummary:: 76 | :toctree: _autosummary 77 | 78 | Stream.csv_reader_from_key 79 | Stream.line_reader_from_key 80 | Stream.dynamic_batch 81 | Stream.partition 82 | Stream.buffered 83 | Stream.repeat 84 | Stream.shuffle 85 | Stream.sliding_window 86 | Stream.prefetch 87 | -------------------------------------------------------------------------------- /docs/src/python/tokenizing.rst: -------------------------------------------------------------------------------- 1 | Tokenizing with MLX data 2 | ======================== 3 | 4 | .. currentmodule:: mlx.data 5 | 6 | MLX data allows sample transformations with the full flexibility of python which 7 | means that you could use any python tokenizer in a 8 | :meth:`Buffer.key_transform`. However, this is likely to be subject to the GIL 9 | which means that effectively only one sample can be tokenized at a time. 10 | 11 | A better choice is to use an :class:`mlx.data.core.CharTrie` to tokenize your 12 | data, taking full advantage of a multicore system. You can build the trie 13 | yourself or use one of the provided helpers to build a trie from a SentencePiece model 14 | or a plain text vocabulary file. 15 | 16 | .. 
code-block:: python 17 | 18 | from mlx.data.core import CharTrie, Tokenizer 19 | 20 | # We can build a trie ourselves 21 | trie = CharTrie() 22 | for t in b"a quick brown fox jumped over the lazy dog".split(): 23 | trie.insert(t) 24 | trie.insert(b" ") 25 | 26 | tokenizer = Tokenizer(trie) 27 | print(tokenizer.tokenize_shortest(b"a quick brown fox jumped over the lazy dog")) 28 | # [0, 9, 1, 9, 2, 9, 3, 9, 4, 9, 5, 9, 6, 9, 7, 9, 8] 29 | 30 | # We can also add all the letters in the trie and then tokenize anything we want 31 | import string 32 | for l in string.ascii_letters: 33 | trie.insert(bytes(l, "utf-8")) 34 | 35 | print(tokenizer.tokenize_shortest(b"This is a quick example")) 36 | # [54, 16, 17, 27, 9, 17, 27, 9, 0, 9, 1, 9, 13, 32, 0, 21, 24, 20, 13] 37 | 38 | # The more useful option is to read the trie from a file, for instance an spm model 39 | from mlx.data.tokenizer_helpers import read_trie_from_spm 40 | 41 | trie, weights = read_trie_from_spm("path/to/spm/model") 42 | tokenizer = Tokenizer(trie, trie_key_scores=weights) 43 | tokenizer.tokenize_shortest(b"This is some more text to tokenize") 44 | 45 | 46 | .. autosummary:: 47 | :toctree: _autosummary 48 | :template: data_core_modules.rst 49 | :recursive: 50 | 51 | core.Tokenizer 52 | core.CharTrie 53 | tokenizer_helpers.read_trie_from_vocab 54 | tokenizer_helpers.read_trie_from_spm 55 | -------------------------------------------------------------------------------- /docs/src/quick_start.rst: -------------------------------------------------------------------------------- 1 | Quick Start Guide 2 | ================= 3 | 4 | Load some data 5 | -------------- 6 | 7 | In MLX data all samples are dictionaries of arrays. The library provides 8 | functions to download and iterate over some common datasets but the goal is to 9 | **provide functions that allow the user to load and process on the fly their 10 | own datasets**. 11 | 12 | Let's start with the simplest example on MNIST. 13 | 14 | .. 
code-block:: python 15 | 16 | # This is the standard way to import and access mlx.data 17 | import mlx.data as dx 18 | 19 | # Let's import MNIST loading 20 | from mlx.data.datasets import load_mnist 21 | 22 | # Loads a buffer with the MNIST images 23 | mnist_train = load_mnist(train=True) 24 | 25 | # Let's shuffle flatten and batch to prepare for MLP training 26 | mnist_mlp = ( 27 | mnist_train 28 | .shuffle() 29 | .to_stream() 30 | .key_transform("image", lambda x: x.astype("float32").reshape(-1)) 31 | .batch(32) 32 | .prefetch(4, 2) 33 | ) 34 | 35 | # Now we can iterate over the batches in normal python 36 | for batch in mnist_mlp: 37 | x, y = batch["image"], batch["label"] 38 | 39 | 40 | MLX Data provides many :ref:`operations ` that transform 41 | samples so you can create arbitrarily complex pipelines. 42 | 43 | 44 | About the GIL 45 | ------------- 46 | 47 | Python functions called by MLX data still run under the Global Interpreter 48 | Lock. To avoid serializing your data pipeline, either drop into Numpy or some 49 | other optimized library as quickly as possible or limit the processing time of 50 | the python part of the data pipeline. 51 | 52 | We would advise, however, to avoid premature optimization and only try to 53 | reduce GIL overhead if you are certain that it is limiting your data processing 54 | pipeline. 55 | 56 | The following are examples where we would use a python function for flexibility 57 | rather than have a specific C++ transformation. 58 | 59 | .. code-block:: python 60 | 61 | # Normalizing images in [0, 1] 62 | dset = dset.key_transform("image", lambda x: x.astype("float32") / 255) 63 | 64 | # Extracting mel spectrogram features 65 | # A big chunk of the time is spent computing the FFT which is done with the GIL off so... 
66 | from mlx.data.features import mfsc 67 | dset = dset.key_transform("audio", mfsc(n_filterbank=80, sampling_freq=16000)) 68 | 69 | # Filter stream samples based on values (empty dict means drop the sample) 70 | dset = dset.sample_transform(lambda s: s if s["length"] > 10 else dict()) 71 | -------------------------------------------------------------------------------- /mlx-data.pc.in: -------------------------------------------------------------------------------- 1 | # Find MLX Data 2 | # 3 | # Defines the following variables: 4 | # 5 | # MLX_DATA_FOUND : True if MLX Data is found 6 | # MLX_DATA_INCLUDE_DIRS : Include directory 7 | # MLX_DATA_LIBRARIES : Libraries to link against 8 | # MLX_DATA_CXX_FLAGS : Additional compiler flags 9 | 10 | @PACKAGE_INIT@ 11 | 12 | include(@PACKAGE_MLX_DATA_CMAKE_INSTALL_MODULE_DIR@/MLXDataTargets.cmake) 13 | 14 | set_and_check(MLX_DATA_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@) 15 | set_and_check(MLX_DATA_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@) 16 | set(MLX_DATA_LIBRARIES mlxdata) 17 | 18 | find_library(MLX_DATA_LIBRARY mlxdata PATHS ${MLX_DATA_LIBRARY_DIRS}) 19 | 20 | set_target_properties(mlx PROPERTIES 21 | CXX_STANDARD 17 22 | INTERFACE_COMPILE_OPTIONS "${MLX_DATA_CXX_FLAGS}" 23 | ) 24 | 25 | include(FindPackageHandleStandardArgs) 26 | find_package_handle_standard_args(MLX_DATA DEFAULT_MSG MLX_DATA_LIBRARY MLX_DATA_INCLUDE_DIRS) 27 | -------------------------------------------------------------------------------- /mlx/data/Buffer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Dataset.h" 6 | #include "mlx/data/buffer/Buffer.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | 11 | // Forward declaration of Stream so we can define toStream(). 
12 | class Stream; 13 | 14 | class Buffer : public Dataset { 15 | public: 16 | Buffer(const std::shared_ptr& self); 17 | 18 | Sample get(int64_t idx) const; 19 | int64_t size() const; 20 | 21 | Buffer batch( 22 | int64_t batch_size, 23 | const std::unordered_map& pad_values = {}, 24 | const std::unordered_map& batch_dims = {}) const; 25 | Buffer batch( 26 | const std::vector& batch_sizes, 27 | const std::unordered_map& pad_values = {}, 28 | const std::unordered_map& batch_dims = {}) const; 29 | 30 | Buffer dynamic_batch( 31 | const std::string& key, 32 | int64_t max_data_size = 0, // batch everything if <= 0 33 | const std::unordered_map& pad_values = {}, 34 | const std::unordered_map& batch_dims = {}) const; 35 | 36 | Buffer dynamic_batch( 37 | const Buffer& size_buffer, 38 | const std::string& key, 39 | int64_t max_data_size, 40 | const std::unordered_map& pad_values = {}, 41 | const std::unordered_map& batch_dims = {}) const; 42 | 43 | Stream ordered_prefetch(int prefetch_size, int num_thread) const; 44 | 45 | Buffer partition(int64_t num_partitions, int64_t partition) const; 46 | Buffer partition_if(bool cond, int64_t num_partitions, int64_t partition) 47 | const; 48 | 49 | Buffer perm(const std::vector& perm); 50 | 51 | Buffer shuffle(); 52 | Buffer shuffle_if(bool cond); 53 | 54 | Stream to_stream(); 55 | 56 | friend class Stream; 57 | }; 58 | 59 | Buffer buffer_from_vector(const std::vector& data); 60 | Buffer buffer_from_vector(std::vector&& data); 61 | Buffer files_from_tar( 62 | const std::string& tarfile, 63 | bool nested = false, 64 | int num_threads = 1); 65 | 66 | } // namespace data 67 | } // namespace mlx 68 | -------------------------------------------------------------------------------- /mlx/data/Sample.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/data/Sample.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace sample { 10 | std::vector keys(const Sample& dict) { 11 | std::vector keys; 12 | for (auto& kv : dict) { 13 | keys.push_back(kv.first); 14 | } 15 | return keys; 16 | } 17 | 18 | std::shared_ptr 19 | check_key(const Sample& input, const std::string& key, ArrayType type) { 20 | auto it = input.find(key); 21 | if (it == input.end()) { 22 | throw std::runtime_error("key <" + key + "> expected"); 23 | } 24 | auto value = it->second; 25 | if (type != ArrayType::Any && value->type() != type) { 26 | throw std::runtime_error("invalid Array type"); 27 | } 28 | return value; 29 | } 30 | } // namespace sample 31 | } // namespace data 32 | } // namespace mlx 33 | -------------------------------------------------------------------------------- /mlx/data/Sample.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/Array.h" 9 | 10 | namespace mlx { 11 | namespace data { 12 | 13 | typedef std::unordered_map> Sample; 14 | 15 | namespace sample { 16 | std::vector keys(const Sample& dict); 17 | std::shared_ptr 18 | check_key(const Sample& input, const std::string& key, ArrayType type); 19 | 20 | } // namespace sample 21 | } // namespace data 22 | } // namespace mlx 23 | -------------------------------------------------------------------------------- /mlx/data/buffer/Batch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/buffer/Batch.h" 4 | #include "mlx/data/core/Utils.h" 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace buffer { 9 | Batch::Batch( 10 | const std::shared_ptr& op, 11 | int64_t batch_size, 12 | const std::unordered_map& pad_values, 13 | const std::unordered_map& batch_dims) 14 | : op_(op), 15 | batchSize_(batch_size), 16 | padValues_(pad_values), 17 | batchDims_(batch_dims) { 18 | if (batch_size <= 0) { 19 | throw std::runtime_error("Batch: batch size must be positive"); 20 | } 21 | size_ = op->size() / batch_size; 22 | if (op->size() % batch_size) { 23 | size_++; 24 | } 25 | } 26 | Batch::Batch( 27 | const std::shared_ptr& op, 28 | const std::vector& batch_sizes, 29 | const std::unordered_map& pad_values, 30 | const std::unordered_map& batch_dims) 31 | : op_(op), 32 | batchSize_(0), 33 | batchOffsets_(batch_sizes.size()), 34 | batchSizes_(batch_sizes), 35 | padValues_(pad_values), 36 | batchDims_(batch_dims) { 37 | int64_t batch_sizes_sum = 0; 38 | for (int64_t i = 0; i < batch_sizes.size(); i++) { 39 | auto batch_size = batch_sizes[i]; 40 | if (batch_size <= 0) { 41 | throw std::runtime_error("Batch: batch size must be positive"); 42 | } 43 | batchOffsets_[i] = batch_sizes_sum; 44 | batch_sizes_sum += batch_size; 45 | } 46 | if (batch_sizes_sum > op->size()) { 47 | throw std::runtime_error("Batch: sum of batch sizes exceeds buffer size"); 48 | } 49 | size_ = batch_sizes.size(); 50 | } 51 | 52 | Sample Batch::get(int64_t idx) const { 53 | if (idx < 0 || idx >= size_) { 54 | throw std::runtime_error("Batch: index out of range"); 55 | } 56 | auto batch_size = 57 | (batchSize_ ? std::min(batchSize_, op_->size() - idx * batchSize_) 58 | : batchSizes_[idx]); 59 | auto batch_offset = (batchSize_ ? 
idx * batchSize_ : batchOffsets_[idx]); 60 | std::vector samples(batch_size); 61 | for (int64_t i = 0; i < batch_size; i++) { 62 | samples[i] = op_->get(batch_offset + i); 63 | } 64 | return core::merge_batch(samples, padValues_, batchDims_); 65 | } 66 | 67 | int64_t Batch::size() const { 68 | return size_; 69 | } 70 | 71 | } // namespace buffer 72 | } // namespace data 73 | } // namespace mlx 74 | -------------------------------------------------------------------------------- /mlx/data/buffer/Batch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class Batch : public Buffer { 12 | public: 13 | Batch( 14 | const std::shared_ptr& op, 15 | int64_t batch_size, 16 | const std::unordered_map& pad_values = {}, 17 | const std::unordered_map& batch_dims = {}); 18 | Batch( 19 | const std::shared_ptr& op, 20 | const std::vector& batch_sizes, 21 | const std::unordered_map& pad_values = {}, 22 | const std::unordered_map& batch_dims = {}); 23 | 24 | virtual Sample get(int64_t idx) const override; 25 | virtual int64_t size() const override; 26 | 27 | private: 28 | std::shared_ptr op_; 29 | int64_t batchSize_; 30 | std::vector batchOffsets_; 31 | std::vector batchSizes_; 32 | std::unordered_map padValues_; 33 | std::unordered_map batchDims_; 34 | int64_t size_; 35 | }; 36 | 37 | } // namespace buffer 38 | } // namespace data 39 | } // namespace mlx 40 | -------------------------------------------------------------------------------- /mlx/data/buffer/Buffer.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | Sample Buffer::get(const int64_t idx) const { 12 | throw std::runtime_error("Buffer::get() NYI"); 13 | } 14 | 15 | int64_t Buffer::size() const { 16 | throw std::runtime_error("Buffer::size() NYI"); 17 | } 18 | 19 | Buffer::~Buffer() {} 20 | 21 | } // namespace buffer 22 | } // namespace data 23 | } // namespace mlx 24 | -------------------------------------------------------------------------------- /mlx/data/buffer/Buffer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Sample.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class Buffer { 12 | public: 13 | Buffer() {}; 14 | 15 | // User-specific 16 | virtual Sample get(int64_t idx) const; 17 | virtual int64_t size() const; 18 | 19 | virtual ~Buffer(); 20 | }; 21 | 22 | } // namespace buffer 23 | } // namespace data 24 | } // namespace mlx 25 | -------------------------------------------------------------------------------- /mlx/data/buffer/DynamicBatch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include "mlx/data/buffer/Batch.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace buffer { 11 | 12 | class DynamicBatch : public Batch { 13 | public: 14 | DynamicBatch( 15 | const std::shared_ptr& buffer, 16 | const std::string& key, 17 | int64_t max_data_size = 0, // batch everything if <= 0 18 | const std::unordered_map& pad_values = {}, 19 | const std::unordered_map& batch_dims = {}); 20 | 21 | DynamicBatch( 22 | const std::shared_ptr& buffer, 23 | const std::shared_ptr& ref_size_buffer, 24 | const std::string& key, 25 | int64_t max_data_size, 26 | const std::unordered_map& pad_values = {}, 27 | const std::unordered_map& batch_dims = {}); 28 | 29 | private: 30 | DynamicBatch( 31 | std::pair, std::vector> 32 | buffer_with_sizes, 33 | const std::unordered_map& pad_values, 34 | const std::unordered_map& batch_dims); 35 | 36 | // returns sorted buffer with number of samples for each batch 37 | static std::pair, std::vector> 38 | dynamic_batch_( 39 | const std::shared_ptr& buffer, 40 | const std::shared_ptr& ref_sizebuffer, 41 | const std::string& key, 42 | int64_t max_data_size, 43 | const std::unordered_map& batch_dims); 44 | }; 45 | 46 | } // namespace buffer 47 | } // namespace data 48 | } // namespace mlx 49 | -------------------------------------------------------------------------------- /mlx/data/buffer/FilesFromTAR.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/buffer/FilesFromTAR.h" 4 | #include "mlx/data/Array.h" 5 | #include "mlx/data/Sample.h" 6 | #include "mlx/data/core/TARReader.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace buffer { 11 | 12 | FilesFromTAR::FilesFromTAR( 13 | const std::string& tarfile, 14 | bool nested, 15 | int num_threads) { 16 | core::TARReader tarreader(tarfile, nested, num_threads); 17 | files_ = tarreader.get_file_list(); 18 | } 19 | 20 | Sample FilesFromTAR::get(int64_t idx) const { 21 | if (idx < 0 || idx >= files_.size()) { 22 | throw std::runtime_error("FilesFromTAR: index out of range"); 23 | } 24 | Sample res; 25 | res["file"] = std::make_shared(files_[idx]); 26 | return res; 27 | } 28 | 29 | int64_t FilesFromTAR::size() const { 30 | return files_.size(); 31 | } 32 | 33 | } // namespace buffer 34 | } // namespace data 35 | } // namespace mlx 36 | -------------------------------------------------------------------------------- /mlx/data/buffer/FilesFromTAR.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class FilesFromTAR : public Buffer { 12 | public: 13 | FilesFromTAR( 14 | const std::string& tarfile, 15 | bool nested = false, 16 | int num_threads = 1); 17 | 18 | Sample get(int64_t idx) const override; 19 | virtual int64_t size() const override; 20 | 21 | private: 22 | std::vector files_; 23 | }; 24 | 25 | } // namespace buffer 26 | } // namespace data 27 | } // namespace mlx 28 | -------------------------------------------------------------------------------- /mlx/data/buffer/FromStream.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/buffer/FromStream.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace buffer { 8 | 9 | FromStream::FromStream( 10 | const std::shared_ptr& stream, 11 | int64_t size) 12 | : FromVector(bufferize_(stream, size)) {} 13 | 14 | std::vector FromStream::bufferize_( 15 | std::shared_ptr stream, 16 | int64_t size) { 17 | std::vector buffer; 18 | if (size > 0) { 19 | buffer.reserve(size); 20 | } 21 | for (int64_t i = 0; (size < 0) || (i < size); i++) { 22 | auto sample = stream->next(); 23 | if (sample.empty()) { 24 | break; 25 | } 26 | buffer.push_back(sample); 27 | } 28 | return buffer; 29 | } 30 | 31 | } // namespace buffer 32 | } // namespace data 33 | } // namespace mlx 34 | -------------------------------------------------------------------------------- /mlx/data/buffer/FromStream.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/FromVector.h" 6 | #include "mlx/data/stream/Stream.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace buffer { 11 | 12 | class FromStream : public FromVector { 13 | public: 14 | FromStream(const std::shared_ptr& stream, int64_t size = -1); 15 | 16 | private: 17 | static std::vector bufferize_( 18 | std::shared_ptr stream, 19 | int64_t size); 20 | }; 21 | 22 | } // namespace buffer 23 | } // namespace data 24 | } // namespace mlx 25 | -------------------------------------------------------------------------------- /mlx/data/buffer/FromVector.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/buffer/FromVector.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace buffer { 8 | 9 | FromVector::FromVector(const std::vector& data) : buffer_(data) { 10 | check_samples_(); 11 | } 12 | 13 | FromVector::FromVector(std::vector&& data) : buffer_(std::move(data)) { 14 | check_samples_(); 15 | } 16 | 17 | Sample FromVector::get(int64_t idx) const { 18 | if (idx < 0 || idx >= buffer_.size()) { 19 | throw std::out_of_range("FromVector: index out of range"); 20 | } 21 | return buffer_[idx]; 22 | } 23 | 24 | int64_t FromVector::size() const { 25 | return buffer_.size(); 26 | } 27 | 28 | void FromVector::check_samples_() const { 29 | for (auto& sample : buffer_) { 30 | if (sample.empty()) { 31 | throw std::runtime_error("FromVector: unexpected empty sample"); 32 | } 33 | } 34 | } 35 | 36 | } // namespace buffer 37 | } // namespace data 38 | } // namespace mlx 39 | -------------------------------------------------------------------------------- /mlx/data/buffer/FromVector.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class FromVector : public Buffer { 12 | public: 13 | FromVector(const std::vector& data); 14 | FromVector(std::vector&& data); 15 | 16 | Sample get(int64_t idx) const override; 17 | virtual int64_t size() const override; 18 | 19 | private: 20 | void check_samples_() const; 21 | std::vector buffer_; 22 | }; 23 | 24 | } // namespace buffer 25 | } // namespace data 26 | } // namespace mlx 27 | -------------------------------------------------------------------------------- /mlx/data/buffer/Partition.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/buffer/Partition.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace buffer { 8 | 9 | Partition::Partition( 10 | std::shared_ptr buffer, 11 | int64_t num_partitions, 12 | int64_t partition) 13 | : buffer_(buffer), numPartitions_(num_partitions), partition_(partition) { 14 | if (num_partitions < 0) { 15 | throw std::runtime_error( 16 | "Partition: number of partitions must be positive"); 17 | } 18 | if (partition < 0 || partition >= num_partitions) { 19 | throw std::runtime_error("Partition: selected partition is out of range"); 20 | } 21 | size_ = buffer->size() / num_partitions; 22 | if (partition_ < (buffer->size() % num_partitions)) { 23 | size_++; 24 | } 25 | } 26 | 27 | Sample Partition::get(int64_t idx) const { 28 | if (idx < 0 || idx > size_) { 29 | throw std::runtime_error("Partition: index out of range"); 30 | } 31 | return buffer_->get(idx * numPartitions_ + partition_); 32 | } 33 | 34 | int64_t Partition::size() const { 35 | return size_; 36 | } 37 | 38 | } // namespace buffer 39 | } // namespace data 40 | } // namespace mlx 41 | -------------------------------------------------------------------------------- /mlx/data/buffer/Partition.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class Partition : public Buffer { 12 | public: 13 | Partition( 14 | std::shared_ptr buffer, 15 | int64_t num_partitions, 16 | int64_t partition); 17 | 18 | virtual Sample get(int64_t idx) const override; 19 | virtual int64_t size() const override; 20 | 21 | private: 22 | std::shared_ptr buffer_; 23 | int64_t numPartitions_; 24 | int64_t partition_; 25 | int64_t size_; 26 | }; 27 | 28 | } // namespace buffer 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/buffer/Perm.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/buffer/Perm.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace buffer { 8 | 9 | Perm::Perm(const std::shared_ptr& op, const std::vector& perm) 10 | : op_(op) { 11 | set_perm_(perm); 12 | } 13 | 14 | Sample Perm::get(int64_t idx) const { 15 | if (idx < 0 || idx >= perm_.size()) { 16 | throw std::runtime_error("Perm: index out of range"); 17 | } 18 | return op_->get(perm_[idx]); 19 | } 20 | 21 | int64_t Perm::size() const { 22 | return perm_.size(); 23 | } 24 | 25 | void Perm::set_perm_(const std::vector& perm) { 26 | for (auto idx : perm) { 27 | if (idx < 0 || idx >= op_->size()) { 28 | throw std::runtime_error("Perm: permutation index out of range"); 29 | } 30 | } 31 | perm_ = perm; 32 | } 33 | 34 | const std::vector& Perm::get_perm() { 35 | return perm_; 36 | } 37 | 38 | } // namespace buffer 39 | } // namespace data 40 | } // namespace mlx 41 | -------------------------------------------------------------------------------- /mlx/data/buffer/Perm.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Buffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class Perm : public Buffer { 12 | public: 13 | Perm(const std::shared_ptr& op, const std::vector& perm); 14 | 15 | Sample get(int64_t idx) const override; 16 | virtual int64_t size() const override; 17 | 18 | const std::vector& get_perm(); 19 | 20 | private: 21 | std::shared_ptr op_; 22 | std::vector perm_; 23 | 24 | // unsafe if it was public 25 | void set_perm_(const std::vector& perm); 26 | }; 27 | 28 | } // namespace buffer 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/buffer/Shuffle.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/buffer/Shuffle.h" 4 | #include "mlx/data/core/State.h" 5 | 6 | #include 7 | #include // iota 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace buffer { 12 | 13 | Shuffle::Shuffle(const std::shared_ptr& buffer) 14 | : Perm(buffer, rand_perm_(buffer->size())) {} 15 | 16 | std::vector Shuffle::rand_perm_(int64_t size) { 17 | auto state = core::get_state(); 18 | std::vector perm(size); 19 | std::iota(perm.begin(), perm.end(), 0); 20 | std::shuffle(perm.begin(), perm.end(), state->randomGenerator); 21 | return perm; 22 | } 23 | 24 | } // namespace buffer 25 | } // namespace data 26 | } // namespace mlx 27 | -------------------------------------------------------------------------------- /mlx/data/buffer/Shuffle.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/buffer/Perm.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | class Shuffle : public Perm { 12 | public: 13 | Shuffle(const std::shared_ptr& buffer); 14 | 15 | private: 16 | std::vector rand_perm_(int64_t size); 17 | }; 18 | 19 | } // namespace buffer 20 | } // namespace data 21 | } // namespace mlx 22 | -------------------------------------------------------------------------------- /mlx/data/buffer/Transform.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/buffer/Transform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace buffer { 10 | 11 | Transform::Transform( 12 | const std::shared_ptr& od, 13 | const std::shared_ptr& op) 14 | : od_(od), ops_({op}) {}; 15 | 16 | Transform::Transform( 17 | const std::shared_ptr& od, 18 | const std::vector>& ops) 19 | : od_(od), ops_(ops) {}; 20 | 21 | Sample Transform::get(const int64_t idx) const { 22 | auto t_sample = od_->get(idx); 23 | if (t_sample.empty()) { 24 | throw std::runtime_error("Transform: cannot return empty sample"); 25 | } 26 | for (auto& op : ops_) { 27 | t_sample = op->apply(t_sample); 28 | if (t_sample.empty()) { 29 | throw std::runtime_error("Transform: cannot return empty sample"); 30 | } 31 | } 32 | return t_sample; 33 | } 34 | 35 | int64_t Transform::size() const { 36 | return od_->size(); 37 | } 38 | 39 | } // namespace buffer 40 | } // namespace data 41 | } // namespace mlx 42 | -------------------------------------------------------------------------------- /mlx/data/buffer/Transform.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/data/buffer/Buffer.h" 10 | #include "mlx/data/op/Op.h" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace buffer { 15 | 16 | class Transform : public Buffer { 17 | public: 18 | Transform( 19 | const std::shared_ptr& od, 20 | const std::shared_ptr& op); 21 | Transform( 22 | const std::shared_ptr& od, 23 | const std::vector>& ops); 24 | 25 | virtual Sample get(int64_t idx) const override; 26 | 27 | virtual int64_t size() const override; 28 | 29 | protected: 30 | std::shared_ptr od_; 31 | std::vector> ops_; 32 | }; 33 | 34 | } // namespace buffer 35 | } // namespace data 36 | } // namespace mlx 37 | -------------------------------------------------------------------------------- /mlx/data/core/BPETokenizer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/data/core/Trie.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace core { 14 | 15 | class BPEMerges { 16 | public: 17 | void add(const std::string& left, const std::string& right, int64_t token); 18 | std::pair can_merge( 19 | std::string_view left, 20 | std::string_view right) const; 21 | 22 | template 23 | std::pair 24 | can_merge(iterator_type left, iterator_type middle, iterator_type end) const { 25 | // switch to std::string_view(left, middle) when in C++20 26 | return can_merge( 27 | std::string_view(&(*left), std::distance(left, middle)), 28 | std::string_view(&(*middle), std::distance(middle, end))); 29 | } 30 | 31 | private: 32 | std::unordered_set strings_; 33 | std::unordered_map< 34 | std::string_view, 35 | std::unordered_map> 36 | merges_; 37 | }; 38 | 39 | class BPETokenizer { 40 | public: 41 | BPETokenizer( 42 | std::shared_ptr> symbols, 43 | std::shared_ptr merges); 44 | 45 | std::vector tokenize(std::string_view input) const; 46 | 47 | 
namespace mlx {
namespace data {
namespace core {

// Computes the shape of a batch
//
// Shapes are accumulated one sample at a time with add(). Non-batched
// dimensions grow to the maximum extent seen so far.
class BatchShape {
 public:
  // batch by prefixing an extra dim
  BatchShape();

  // batch along a specified dim (negative values count from the last dim)
  BatchShape(int dim);

  // add a shape to the batch
  void add(const std::vector<int64_t>& shape);

  void clear();

  int64_t size() const;
  const std::vector<int64_t>& shape() const;
  int64_t num_sample() const;
  int64_t operator[](int dim) const;

 private:
  std::vector<int64_t> shape_;
  int dim_;
  bool nodim_;
  int64_t num_sample_;
};

// dim_ is unused in this mode but is still initialized so that copying the
// object never reads an indeterminate value.
BatchShape::BatchShape() : dim_(0), nodim_(true), num_sample_(0) {}
BatchShape::BatchShape(int dim) : dim_(dim), nodim_(false), num_sample_(0) {}

// Number of elements of the batch: the product of all dimensions (1 when no
// shape has been added yet).
int64_t BatchShape::size() const {
  int64_t size = 1;
  for (auto dim : shape_) {
    size *= dim;
  }
  return size;
}

const std::vector<int64_t>& BatchShape::shape() const {
  return shape_;
}

// Fold the shape of one more sample into the batch shape.
//
// Throws std::runtime_error if the rank differs from previously added shapes
// or if the batching dimension does not exist in `shape`.
void BatchShape::add(const std::vector<int64_t>& shape) {
  const int64_t ndim = static_cast<int64_t>(shape.size());
  if (nodim_) {
    if (num_sample_ == 0) {
      // Slot 0 is the new batch dimension; it starts at 0 and is
      // incremented below.
      shape_.resize(shape.size() + 1, 0);
      std::copy(shape.begin(), shape.end(), shape_.begin() + 1);
    } else {
      if ((shape.size() + 1) != shape_.size()) {
        throw std::runtime_error(
            "BatchShape: batched arrays expected to have consistent shapes");
      }
      for (int64_t d = 0; d < ndim; d++) {
        shape_[d + 1] = std::max(shape_[d + 1], shape[d]);
      }
    }
    shape_[0] += 1;
  } else {
    int64_t dim = dim_;
    if (dim < 0) {
      dim += ndim;
    }
    // Both bounds are checked explicitly rather than relying on a negative
    // value wrapping around in a signed/unsigned comparison.
    if (dim < 0 || dim >= ndim) {
      throw std::runtime_error("BatchShape: dimension out of bound");
    }
    if (num_sample_ == 0) {
      shape_ = shape;
    } else {
      if (shape.size() != shape_.size()) {
        throw std::runtime_error(
            "BatchShape: batched arrays expected to have consistent shapes");
      }
      for (int64_t d = 0; d < ndim; d++) {
        if (d == dim) {
          // Samples are concatenated along the batching dimension.
          shape_[d] += shape[d];
        } else {
          shape_[d] = std::max(shape_[d], shape[d]);
        }
      }
    }
  }
  num_sample_++;
}

int64_t BatchShape::num_sample() const {
  return num_sample_;
}

// Forget all accumulated shapes; the batching mode is kept.
void BatchShape::clear() {
  shape_.clear();
  num_sample_ = 0;
}

// Extent of dimension `dim` of the batch shape; negative values count from
// the last dimension. Throws std::runtime_error when out of bounds.
int64_t BatchShape::operator[](int dim) const {
  const int64_t ndim = static_cast<int64_t>(shape_.size());
  int64_t index = dim;
  if (index < 0) {
    index += ndim;
  }
  if (index < 0 || index >= ndim) {
    throw std::runtime_error("BatchShape: dimension out of bound");
  }
  return shape_[index];
}

} // namespace core
} // namespace data
} // namespace mlx
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "bxzstr/bxzstr.hpp" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace core { 15 | 16 | class CSVReader { 17 | public: 18 | CSVReader( 19 | const std::string& file, 20 | const char sep = ',', 21 | const char quote = '"'); 22 | CSVReader( 23 | const std::shared_ptr& uf, 24 | const char sep = ',', 25 | const char quote = '"'); 26 | std::vector next(); 27 | void reset(); 28 | 29 | private: 30 | void parse_line_( 31 | const std::string& line, 32 | std::vector& fields, 33 | int& current_state, 34 | std::string& current_field) const; 35 | 36 | std::string filename_; 37 | int numFields_ = -1; 38 | int numLine_ = 0; 39 | char sep_ = ','; 40 | char quote_ = '"'; 41 | const char lf_ = '\n'; 42 | const char cr_ = '\r'; 43 | std::shared_ptr uf_; 44 | std::shared_ptr f_; 45 | }; 46 | 47 | } // namespace core 48 | } // namespace data 49 | } // namespace mlx 50 | -------------------------------------------------------------------------------- /mlx/data/core/FileFetcher.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | #include "mlx/data/core/ThreadPool.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace mlx { 17 | namespace data { 18 | namespace core { 19 | 20 | class FileFetcherHandle { 21 | public: 22 | FileFetcherHandle(int64_t rank) : rank_(rank) {}; 23 | 24 | private: 25 | int64_t rank_; 26 | friend class FileFetcher; 27 | }; 28 | 29 | // Note that FileFetcher holds a weak reference on a given filename 30 | // If it is not valid anymore, it will be fetched again 31 | class FileFetcher { 32 | public: 33 | // In a multi-threaded environment, use numKeptFiles with caution: Make 34 | // sure it is large enough, to ensure threads won't compete on the files 35 | // to keep locally. 36 | FileFetcher( 37 | int num_prefetch_max = 1, 38 | int num_prefetch_threads = 1, 39 | int num_kept_files = 0, 40 | bool verbose = false); 41 | 42 | void prefetch(const std::vector& filenames); 43 | 44 | // Must be called in the destructor of any subclass 45 | // because prefetch calls the virtual backendFetch() 46 | // which would then be destroyed before ~FileFetcher() 47 | void cancel_prefetch(); 48 | 49 | std::shared_ptr fetch(const std::string& filename) const; 50 | 51 | // Erase a file from cache, and call backend erase 52 | void erase(const std::string& filename) const; 53 | 54 | virtual void backend_fetch(const std::string& filename) const; 55 | 56 | virtual void backend_erase(const std::string& filename) const; 57 | 58 | virtual ~FileFetcher(); 59 | 60 | protected: 61 | void fill_queue_() const; 62 | std::unique_ptr threadPool_; 63 | mutable std::deque prefetchFilenames_; 64 | mutable std::shared_mutex mutex_; 65 | int numPrefetchMax_; 66 | int numKeptFiles_; 67 | mutable int64_t fileRank_; 68 | bool verbose_; 69 | 70 | mutable std::unordered_map> queuedFiles_; 71 | 72 | mutable std::unordered_map> 73 | cachedFiles_; 74 | }; 75 | 76 | } // namespace core 77 | } 
// namespace data 78 | } // namespace mlx 79 | -------------------------------------------------------------------------------- /mlx/data/core/Levenshtein.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/Array.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace core { 8 | 9 | std::shared_ptr levenshtein( 10 | const std::shared_ptr arr1, 11 | const std::shared_ptr len1, 12 | const std::shared_ptr arr2, 13 | const std::shared_ptr len2); 14 | 15 | } 16 | } // namespace data 17 | } // namespace mlx 18 | -------------------------------------------------------------------------------- /mlx/data/core/Numpy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace mlx { 14 | namespace data { 15 | namespace core { 16 | 17 | std::shared_ptr load_numpy(const std::string& filename); 18 | 19 | /// @brief Read numpy (npy) array file. 20 | /// 21 | /// See also: 22 | /// https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html 23 | /// 24 | /// @param filename file to read 25 | /// @return Array with the contents of the file 26 | std::shared_ptr load_numpy( 27 | std::istream& stream, 28 | const std::string& filename = nullptr); 29 | 30 | } // namespace core 31 | } // namespace data 32 | } // namespace mlx 33 | -------------------------------------------------------------------------------- /mlx/data/core/State.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
namespace mlx {
namespace data {
namespace core {

// Random state: a generator plus a version counter that changes whenever the
// global state is re-seeded.
struct State {
  std::mt19937 randomGenerator;
  int64_t version;
};

// thread-local state
std::shared_ptr<State> get_state();

// should be called in main thread only
void set_state(int64_t seed);

// Process-wide reference state; threads take private copies of it.
static State global_state;

// Re-seed the global generator and bump the version so that every thread
// refreshes its private copy on its next get_state() call.
void set_state(int64_t seed) {
  std::mt19937 reseeded(seed);
  global_state.randomGenerator = reseeded;
  ++global_state.version;
}

// Return the calling thread's private copy of the state, (re)creating it
// from the global state when it does not exist yet or is out of date.
std::shared_ptr<State> get_state() {
  static thread_local std::shared_ptr<State> local;
  const bool up_to_date = local && (local->version == global_state.version);
  if (!up_to_date) {
    local = std::make_shared<State>(global_state);
  }
  return local;
}

} // namespace core
} // namespace data
} // namespace mlx
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace mlx { 14 | namespace data { 15 | namespace core { 16 | 17 | typedef std::unordered_map> 18 | TARFileIndex; 19 | 20 | class TARReader { 21 | public: 22 | TARReader( 23 | const std::string& filename, 24 | bool nested = false, 25 | int num_threads = 1); 26 | 27 | bool contains(const std::string& filename); 28 | std::shared_ptr get(const std::string& filename); 29 | std::vector get_file_list(); 30 | 31 | private: 32 | std::string filename_; 33 | TARFileIndex index_; 34 | }; 35 | 36 | } // namespace core 37 | } // namespace data 38 | } // namespace mlx 39 | -------------------------------------------------------------------------------- /mlx/data/core/ThreadController.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace core { 9 | 10 | typedef std::vector ThreadControllerState; 11 | 12 | struct ThreadControllerSym; 13 | 14 | class ThreadController { 15 | public: 16 | ThreadController(); 17 | 18 | ThreadControllerState limit(); 19 | void restore(const ThreadControllerState& state); 20 | 21 | private: 22 | std::vector> symbols_; 23 | }; 24 | 25 | } // namespace core 26 | } // namespace data 27 | } // namespace mlx 28 | -------------------------------------------------------------------------------- /mlx/data/core/ThreadPool.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/core/ThreadPool.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace core { 8 | 9 | std::shared_ptr ThreadPool::thread_controller = nullptr; 10 | 11 | ThreadPool::ThreadPool(size_t thread_count) { 12 | if (!thread_controller) { 13 | thread_controller = std::make_shared(); 14 | } 15 | for (size_t i = 0; i < thread_count; ++i) { 16 | // start waiting threads. Workers listen for changes through 17 | // the ThreadPool member condition_variable 18 | threads_.emplace_back(std::thread([&]() { 19 | std::unique_lock queue_lock(task_mutex_, std::defer_lock); 20 | 21 | while (true) { 22 | queue_lock.lock(); 23 | task_cv_.wait(queue_lock, [&]() -> bool { 24 | return !tasks_.empty() || stop_threads_; 25 | }); 26 | 27 | // used by dtor to stop all threads without having to 28 | // unceremoniously stop tasks. The tasks must all be 29 | // finished, lest we break a promise and risk a `future` 30 | // object throwing an exception. 31 | if (stop_threads_ && tasks_.empty()) 32 | return; 33 | 34 | // to initialize temp_task, we must move the unique_ptr 35 | // from the queue to the local stack. Since a unique_ptr 36 | // cannot be copied (obviously), it must be explicitly 37 | // moved. This transfers ownership of the pointed-to 38 | // object to *this, as specified in 20.11.1.2.1 39 | // [unique.ptr.single.ctor]. 
40 | auto temp_task = std::move(tasks_.front()); 41 | 42 | tasks_.pop(); 43 | queue_lock.unlock(); 44 | 45 | auto thread_state = thread_controller->limit(); 46 | (*temp_task)(); 47 | thread_controller->restore(thread_state); 48 | } 49 | })); 50 | } 51 | } 52 | 53 | ThreadPool::~ThreadPool() { 54 | stop_threads_ = true; 55 | task_cv_.notify_all(); 56 | 57 | for (std::thread& thread : threads_) { 58 | thread.join(); 59 | } 60 | } 61 | } // namespace core 62 | } // namespace data 63 | } // namespace mlx 64 | -------------------------------------------------------------------------------- /mlx/data/core/Tokenizer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/core/Graph.h" 9 | #include "mlx/data/core/Trie.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace core { 14 | 15 | std::shared_ptr> tokenize( 16 | std::shared_ptr> trie, 17 | const std::string& input, 18 | bool ignore_unk = false); 19 | 20 | class Tokenizer { 21 | public: 22 | Tokenizer( 23 | std::shared_ptr> trie, 24 | bool ignore_unk = false, 25 | const std::vector& trie_key_scores = {}); 26 | std::shared_ptr> tokenize(const std::string& input) const; 27 | std::vector tokenize_shortest(const std::string& input) const; 28 | std::vector tokenize_rand(const std::string& input) const; 29 | 30 | private: 31 | std::shared_ptr> trie_; 32 | bool ignoreUnk_; 33 | std::vector trieKeyScores_; 34 | bool trieKeyScoresPositive_; 35 | }; 36 | 37 | class TokenizerIterator { 38 | public: 39 | TokenizerIterator(std::shared_ptr> graph); 40 | std::vector next(); 41 | 42 | private: 43 | std::shared_ptr> g_; 44 | std::vector edgeIndices_; // edge indices for path so far 45 | std::vector backEdgeIds_; // back edge ids for path so far 46 | int64_t currentNodeId_; // current node 47 | std::vector currentTokens_; // current tokenization 48 | std::unordered_set::const_iterator 
startNodeIterator_; 49 | bool new_start_(); 50 | void forward_(); 51 | }; 52 | 53 | } // namespace core 54 | } // namespace data 55 | } // namespace mlx 56 | -------------------------------------------------------------------------------- /mlx/data/core/Utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include "mlx/data/Array.h" 4 | #include "mlx/data/Sample.h" 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace core { 9 | 10 | std::pair, std::shared_ptr> uniq( 11 | const std::shared_ptr src, 12 | const std::shared_ptr src_length, 13 | int dim, 14 | double pad); 15 | 16 | std::pair, std::shared_ptr> remove( 17 | const std::shared_ptr src, 18 | const std::shared_ptr src_length, 19 | int dim, 20 | double value, 21 | double pad); 22 | 23 | std::shared_ptr replace( 24 | const std::shared_ptr& src, 25 | const std::shared_ptr& old, 26 | const std::shared_ptr& replacement, 27 | int count); 28 | 29 | Sample merge_batch( 30 | const std::vector& samples, 31 | const std::unordered_map& pad_values = {}, 32 | const std::unordered_map& batch_dims = {}); 33 | 34 | } // namespace core 35 | } // namespace data 36 | } // namespace mlx 37 | -------------------------------------------------------------------------------- /mlx/data/core/audio/Audio.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/core/audio/AudioPrivate.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace core { 8 | namespace audio { 9 | 10 | std::shared_ptr load(const std::string& path, AudioInfo* info) { 11 | return load_sndfile(path, info); 12 | } 13 | 14 | std::shared_ptr load( 15 | const std::shared_ptr& contents, 16 | AudioInfo* info) { 17 | return load_sndfile(contents, info); 18 | } 19 | 20 | AudioInfo info(const std::string& path) { 21 | return info_sndfile(path); 22 | } 23 | 24 | AudioInfo info(const std::shared_ptr& contents) { 25 | return info_sndfile(contents); 26 | } 27 | 28 | void verify_audio(const std::shared_ptr& audio) { 29 | auto dimensions = audio->shape().size(); 30 | if (dimensions != 2) { 31 | throw std::runtime_error( 32 | "verifyAudio: audio must be 2 dimension Array (SC)"); 33 | } 34 | } 35 | 36 | } // namespace audio 37 | } // namespace core 38 | } // namespace data 39 | } // namespace mlx 40 | -------------------------------------------------------------------------------- /mlx/data/core/audio/Audio.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace core { 10 | namespace audio { 11 | 12 | /// Image metadata 13 | struct AudioInfo { 14 | public: 15 | int64_t frames; 16 | int sampleRate; 17 | int channels; 18 | }; 19 | 20 | std::shared_ptr load(const std::string& path, AudioInfo* info); 21 | std::shared_ptr load( 22 | const std::shared_ptr& contents, 23 | AudioInfo* info); 24 | 25 | AudioInfo info(const std::string& path); 26 | AudioInfo info(const std::shared_ptr& contents); 27 | 28 | /// Verify that the given Array is structured like audio: 29 | /// two dimensions (s, c). 30 | void verify_audio(const std::shared_ptr& audio); 31 | 32 | /// Return the number of frames in the audio. Requires that `verifyAudio()` be 33 | /// called previously. 
34 | inline const int64_t frames(const std::shared_ptr& audio) { 35 | return audio->shape()[0]; 36 | } 37 | 38 | /// Return the channel count of the audio. Requires that `verifyAudio()` be 39 | /// called previously. 40 | inline const int64_t channels(const std::shared_ptr& audio) { 41 | return audio->shape()[1]; 42 | } 43 | 44 | /// Resample mode -- these should match 1:1 with the libsamplerate enum, e.g. 45 | /// SRC_SINC_BEST_QUALITY 46 | enum class ResampleMode { 47 | best = 0, 48 | medium = 1, 49 | fastest = 2, 50 | zeroOrderHold = 3, 51 | linear = 4, 52 | }; 53 | 54 | std::shared_ptr resample( 55 | const std::shared_ptr& audio, 56 | ResampleMode resample_mode, 57 | int src_sample_rate, 58 | int dst_sample_rate); 59 | 60 | } // namespace audio 61 | } // namespace core 62 | } // namespace data 63 | } // namespace mlx 64 | -------------------------------------------------------------------------------- /mlx/data/core/audio/AudioPrivate.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/core/audio/Audio.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace core { 10 | namespace audio { 11 | 12 | std::shared_ptr load_sndfile(const std::string& path, AudioInfo* info); 13 | std::shared_ptr load_sndfile( 14 | const std::shared_ptr& contents, 15 | AudioInfo* info); 16 | 17 | AudioInfo info_sndfile(const std::string& path); 18 | AudioInfo info_sndfile(const std::shared_ptr& contents); 19 | 20 | } // namespace audio 21 | } // namespace core 22 | } // namespace data 23 | } // namespace mlx 24 | -------------------------------------------------------------------------------- /mlx/data/core/audio/AudioSampleRate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include "mlx/data/core/audio/Audio.h" 5 | 6 | #ifdef MLX_HAS_SAMPLERATE 7 | #include 8 | #endif 9 | 10 | namespace mlx { 11 | namespace data { 12 | namespace core { 13 | namespace audio { 14 | 15 | #ifdef MLX_HAS_SAMPLERATE 16 | 17 | std::shared_ptr resample( 18 | const std::shared_ptr& audio, 19 | ResampleMode resample_mode, 20 | int src_sample_rate, 21 | int dst_sample_rate) { 22 | if ((dst_sample_rate <= 0) || (src_sample_rate == dst_sample_rate)) { 23 | return audio; 24 | } 25 | 26 | int64_t audio_channels = channels(audio); 27 | int64_t audio_length = frames(audio); 28 | 29 | double length_scale = static_cast(dst_sample_rate) / 30 | static_cast(src_sample_rate); 31 | int64_t new_audio_length = 32 | static_cast(std::floor(audio_length * length_scale)); 33 | auto result = std::make_shared( 34 | ArrayType::Float, new_audio_length, audio_channels); 35 | SRC_DATA src_data; 36 | src_data.data_in = audio->data(); 37 | src_data.input_frames = audio_length; 38 | src_data.data_out = result->data(); 39 | src_data.output_frames = new_audio_length; 40 | src_data.src_ratio = length_scale; 41 | auto status = src_simple(&src_data, (int)resample_mode, audio_channels); 42 | if (status) { 43 | std::string msg("audio: libsamplerate failed with: "); 44 | msg += std::string(src_strerror(status)); 45 | throw std::runtime_error(msg); 46 | } 47 | 48 | if (new_audio_length != src_data.output_frames_gen) { 49 | std::vector offset(2, 0); 50 | auto new_shape = result->shape(); 51 | new_shape[0] = src_data.output_frames_gen; 52 | result = array::sub(result, offset, new_shape); 53 | } 54 | 55 | return result; 56 | } 57 | 58 | #else 59 | 60 | std::shared_ptr resample( 61 | const std::shared_ptr& audio, 62 | ResampleMode resample_mode, 63 | int src_sample_rate, 64 | int dst_sample_rate) { 65 | throw std::runtime_error( 66 | "audio: mlx was not compiled with sample rate conversion support (libsamplerate)"); 67 | } 68 | 69 | #endif 70 | 71 | } // namespace audio 72 | } // 
namespace core 73 | } // namespace data 74 | } // namespace mlx 75 | -------------------------------------------------------------------------------- /mlx/data/core/image/Image.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace core { 10 | namespace image { 11 | 12 | /// Image metadata 13 | struct ImageInfo { 14 | public: 15 | int width; 16 | int height; 17 | int channels; 18 | }; 19 | 20 | std::shared_ptr load(const std::string& path); 21 | std::shared_ptr load(const std::shared_ptr& contents); 22 | 23 | ImageInfo info(const std::string& path); 24 | ImageInfo info(const std::shared_ptr& contents); 25 | 26 | bool save(const std::shared_ptr& image, const std::string& path); 27 | 28 | /// Verify that the given Array is structured like an image: 29 | /// three dimensions (w, h, c), last dimension looks like channels 30 | /// and is 1, 3, or 4. 31 | void verify_image(const std::shared_ptr& image); 32 | 33 | /// Return the width of the image. Requires that `verifyImage()` be called 34 | /// previously. 35 | inline const int64_t width(const std::shared_ptr& image) { 36 | return image->shape()[1]; 37 | } 38 | 39 | /// Return the height of the image. Requires that `verifyImage()` be called 40 | /// previously. 41 | inline const int64_t height(const std::shared_ptr& image) { 42 | return image->shape()[0]; 43 | } 44 | 45 | /// Return the channel count of the image. Requires that `verifyImage()` be 46 | /// called previously. 
47 | inline const int64_t channels(const std::shared_ptr& image) { 48 | return image->shape()[2]; 49 | } 50 | 51 | std::shared_ptr scale( 52 | const std::shared_ptr& image, 53 | double scale); 54 | std::shared_ptr resize( 55 | const std::shared_ptr& image, 56 | int64_t dw, 57 | int64_t dh); // may alter aspect ratio 58 | std::shared_ptr crop( 59 | const std::shared_ptr& image, 60 | int64_t x, 61 | int64_t y, 62 | int64_t w, 63 | int64_t h); 64 | std::shared_ptr 65 | affine(const std::shared_ptr& image, const float mx[6], bool crop); 66 | std::shared_ptr 67 | rotate(const std::shared_ptr& image, double angle, bool crop); 68 | std::shared_ptr hflip(const std::shared_ptr& image); 69 | std::shared_ptr channel_reduction( 70 | const std::shared_ptr& image, 71 | const float bias, 72 | const float multiplier[3]); 73 | 74 | } // namespace image 75 | } // namespace core 76 | } // namespace data 77 | } // namespace mlx 78 | -------------------------------------------------------------------------------- /mlx/data/core/image/ImageIO.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/core/image/ImagePrivate.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace core { 8 | namespace image { 9 | 10 | std::shared_ptr load(const std::string& path) { 11 | auto result = load_jpeg(path); 12 | if (result == nullptr) { 13 | result = load_stbi(path); 14 | } 15 | return result; 16 | } 17 | 18 | std::shared_ptr load(const std::shared_ptr& contents) { 19 | auto result = load_jpeg(contents); 20 | if (result == nullptr) { 21 | result = load_stbi(contents); 22 | } 23 | return result; 24 | } 25 | 26 | ImageInfo info(const std::string& path) { 27 | return info_stbi(path); 28 | } 29 | 30 | ImageInfo info(const std::shared_ptr& contents) { 31 | return info_stbi(contents); 32 | } 33 | 34 | bool save(const std::shared_ptr& image, const std::string& path) { 35 | verify_image(image); 36 | return save_jpeg(image, path); 37 | } 38 | 39 | void verify_image(const std::shared_ptr& image) { 40 | auto dimensions = image->shape().size(); 41 | if (dimensions != 3) { 42 | throw std::runtime_error( 43 | "verifyImage: image must be 3 dimension Array (HWC)"); 44 | } 45 | 46 | if (channels(image) == 0 || channels(image) > 4) { 47 | throw std::runtime_error("verifyImage: channels must be 0 <= c <= 4"); 48 | } 49 | } 50 | 51 | } // namespace image 52 | } // namespace core 53 | } // namespace data 54 | } // namespace mlx 55 | -------------------------------------------------------------------------------- /mlx/data/core/image/ImagePrivate.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/core/image/Image.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace core { 10 | namespace image { 11 | 12 | /// implementation for mlx::data::image 13 | 14 | std::shared_ptr load_stbi(const std::string& path); 15 | std::shared_ptr load_stbi(const std::shared_ptr contents); 16 | 17 | ImageInfo info_stbi(const std::string& path); 18 | ImageInfo info_stbi(const std::shared_ptr contents); 19 | 20 | std::shared_ptr load_jpeg(const std::string& path); 21 | std::shared_ptr load_jpeg(const std::shared_ptr contents); 22 | 23 | bool save_jpeg( 24 | const std::shared_ptr image, 25 | const std::string& path); 26 | 27 | } // namespace image 28 | } // namespace core 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/core/image/ImageSTBI.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/core/image/ImagePrivate.h" 4 | 5 | /* DEBUG: #define STBI_NEON 1 */ 6 | #define STB_IMAGE_IMPLEMENTATION 7 | #include 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace core { 12 | namespace image { 13 | 14 | std::shared_ptr load_stbi(const std::string& path) { 15 | ImageInfo info = info_stbi(path); 16 | int required_channels = info.channels; 17 | if (required_channels > 3) { 18 | required_channels = 3; 19 | } 20 | 21 | int w, h, c; 22 | unsigned char* img = nullptr; 23 | img = stbi_load(path.c_str(), &w, &h, &c, required_channels); 24 | 25 | if (!img) { 26 | throw std::runtime_error("load_stbi: could not load <" + path + ">"); 27 | } 28 | 29 | std::vector shape = {h, w, required_channels}; 30 | return std::make_shared( 31 | UInt8, shape, std::shared_ptr((void*)img, stbi_image_free)); 32 | } 33 | 34 | std::shared_ptr load_stbi(const std::shared_ptr contents) { 35 | ImageInfo info = info_stbi(contents); 36 | int required_channels = info.channels; 37 | if (required_channels > 3) { 38 | required_channels = 3; 39 | } 40 | 41 | int w, h, c; 42 | unsigned char* img = nullptr; 43 | img = stbi_load_from_memory( 44 | contents->data(), 45 | contents->size(), 46 | &w, 47 | &h, 48 | &c, 49 | required_channels); 50 | 51 | if (!img) { 52 | throw std::runtime_error("load_stbi: could not load from memory"); 53 | } 54 | 55 | std::vector shape = {h, w, required_channels}; 56 | return std::make_shared( 57 | UInt8, shape, std::shared_ptr((void*)img, stbi_image_free)); 58 | } 59 | 60 | ImageInfo info_stbi(const std::string& path) { 61 | int w, h, c; 62 | if (!stbi_info(path.c_str(), &w, &h, &c)) { 63 | return {}; 64 | } 65 | return { 66 | .width = w, 67 | .height = h, 68 | .channels = c, 69 | }; 70 | } 71 | 72 | ImageInfo info_stbi(const std::shared_ptr contents) { 73 | int w, h, c; 74 | if (!stbi_info_from_memory( 75 | contents->data(), contents->size(), &w, &h, &c)) { 76 | return {}; 77 | } 78 | // clamp the number of channels to 3 -- no alpha 79 
/// Read-only stream buffer over a caller-owned byte range `[base, base+size)`.
/// The bytes are not copied, so the range must stay alive for as long as the
/// buffer is in use. Only a get area is configured (no put area), and seeking
/// within the range is supported so `tellg()` / `seekg()` work on streams
/// built on top of it.
struct membuf : std::streambuf {
  membuf(char const* base, size_t size) {
    // std::streambuf stores its get area as `char*`, hence the const_cast.
    char* p(const_cast<char*>(base));
    this->setg(p, p, p + size);
  }

 protected:
  /// Move the read position by `off` relative to `dir`. Returns the new
  /// absolute offset, or pos_type(off_type(-1)) when the target lies outside
  /// the buffer or when an input seek was not requested.
  pos_type seekoff(
      off_type off,
      std::ios_base::seekdir dir,
      std::ios_base::openmode which = std::ios_base::in) override {
    const pos_type failure = pos_type(off_type(-1));
    if (!(which & std::ios_base::in)) {
      return failure;
    }

    const off_type length = egptr() - eback();
    off_type origin = 0;
    if (dir == std::ios_base::cur) {
      origin = gptr() - eback();
    } else if (dir == std::ios_base::end) {
      origin = length;
    } else if (dir != std::ios_base::beg) {
      return failure;
    }

    const off_type target = origin + off;
    if (target < 0 || target > length) {
      return failure;
    }
    setg(eback(), eback() + target, egptr());
    return pos_type(target);
  }

  /// Move the read position to the absolute offset `pos`.
  pos_type seekpos(
      pos_type pos,
      std::ios_base::openmode which = std::ios_base::in) override {
    return seekoff(off_type(pos), std::ios_base::beg, which);
  }
};
2 | 3 | #include "mlx/data/core/video/VideoPrivate.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace core { 8 | namespace video { 9 | 10 | static std::shared_ptr load(VideoReader& reader) { 11 | VideoInfo video_info = reader.info(); 12 | 13 | // allocate storage for the full video 14 | auto result = std::make_shared( 15 | ArrayType::UInt8, 16 | video_info.frames, 17 | video_info.height, 18 | video_info.width, 19 | 3); 20 | 21 | for (int frame_number = 0; frame_number < video_info.frames; frame_number++) { 22 | auto frame = reader.read_frame(array::slice(result, frame_number)); 23 | 24 | if (result == nullptr) { 25 | // finished early -- metadata does not match actual 26 | // number of frames, make a new buffer to fit the data 27 | 28 | auto new_shape = result->shape(); 29 | new_shape[0] = frame_number; 30 | auto new_frames = std::make_shared( 31 | ArrayType::UInt8, 32 | frame_number, 33 | video_info.height, 34 | video_info.width, 35 | 3); 36 | memcpy( 37 | new_frames->data(), 38 | result->data(), 39 | new_frames->itemsize() * new_frames->size()); 40 | result = new_frames; 41 | 42 | break; 43 | } 44 | } 45 | 46 | return result; 47 | } 48 | 49 | std::shared_ptr load(const std::string& path) { 50 | VideoReader reader = VideoReader(path); 51 | return load(reader); 52 | } 53 | 54 | std::shared_ptr load(const std::shared_ptr& contents) { 55 | VideoReader reader = VideoReader(contents); 56 | return load(reader); 57 | } 58 | 59 | VideoInfo info(const std::string& path) { 60 | VideoReader reader = VideoReader(path); 61 | return reader.info(); 62 | } 63 | 64 | VideoInfo info(const std::shared_ptr& contents) { 65 | VideoReader reader = VideoReader(contents); 66 | return reader.info(); 67 | } 68 | 69 | void verify_video(const std::shared_ptr& video) { 70 | auto dimensions = video->shape().size(); 71 | if (dimensions != 4) { 72 | throw std::runtime_error( 73 | "verifyVideo: video must be 4 dimension Array (FHWC)"); 74 | } 75 | 76 | if (channels(video) == 0 || 
channels(video) > 4) { 77 | throw std::runtime_error("verifyVideo: channels must be 0 <= c <= 4"); 78 | } 79 | } 80 | 81 | } // namespace video 82 | } // namespace core 83 | } // namespace data 84 | } // namespace mlx 85 | -------------------------------------------------------------------------------- /mlx/data/core/video/Video.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/Array.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace core { 10 | namespace video { 11 | 12 | /// Video metadata 13 | struct VideoInfo { 14 | public: 15 | int width; 16 | int height; 17 | int channels; 18 | int64_t frames; 19 | }; 20 | 21 | std::shared_ptr load(const std::string& path); 22 | std::shared_ptr load(const std::shared_ptr& contents); 23 | 24 | VideoInfo info(const std::string& path); 25 | VideoInfo info(const std::shared_ptr& contents); 26 | 27 | /// Verify that the given Array is structured like a video: 28 | /// four dimensions (f, w, h, c), last dimension looks like channels 29 | /// and is 1, 3, or 4. 30 | void verify_video(const std::shared_ptr& video); 31 | 32 | /// Return the width of the image. Requires that `verifyVideo()` be called 33 | /// previously. 34 | inline const int64_t width(const std::shared_ptr& video) { 35 | return video->shape()[2]; 36 | } 37 | 38 | /// Return the height of the image. Requires that `verifyVideo()` be called 39 | /// previously. 40 | inline const int64_t height(const std::shared_ptr& video) { 41 | return video->shape()[1]; 42 | } 43 | 44 | /// Return the channel count of the image. Requires that `verifyVideo()` be 45 | /// called previously. 46 | inline const int64_t channels(const std::shared_ptr& video) { 47 | return video->shape()[3]; 48 | } 49 | 50 | /// Return the frame count of the image. Requires that `verifyVideo()` be 51 | /// called previously. 
52 | inline const int64_t frames(const std::shared_ptr& video) { 53 | return video->shape()[0]; 54 | } 55 | 56 | } // namespace video 57 | } // namespace core 58 | } // namespace data 59 | } // namespace mlx 60 | -------------------------------------------------------------------------------- /mlx/data/core/video/VideoPrivate.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/core/video/Video.h" 6 | 7 | #include 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace core { 12 | namespace video { 13 | 14 | // Opaque state 15 | class VideoReaderState; 16 | 17 | class VideoReader { 18 | public: 19 | VideoReader(const std::string& filename); 20 | VideoReader(const std::shared_ptr& contents); 21 | 22 | ~VideoReader(); 23 | 24 | VideoInfo info(); 25 | std::shared_ptr read_frame( 26 | std::shared_ptr destination = nullptr); 27 | 28 | private: 29 | VideoReaderState* state_; 30 | }; 31 | 32 | } // namespace video 33 | } // namespace core 34 | } // namespace data 35 | } // namespace mlx 36 | -------------------------------------------------------------------------------- /mlx/data/op/FilterByShape.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/FilterByShape.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | FilterByShape::FilterByShape( 9 | const std::string& key, 10 | int dim, 11 | int64_t low, 12 | int64_t high) 13 | : key_(key), dim_(dim), low_(low), high_(high) {} 14 | Sample FilterByShape::apply(const Sample& sample) const { 15 | auto array = sample::check_key(sample, key_, ArrayType::Any); 16 | auto dim = dim_; 17 | if (dim < 0) { 18 | dim += array->ndim(); 19 | } 20 | if (dim < 0 || dim >= array->ndim()) { 21 | return Sample(); 22 | } 23 | if ((low_ >= 0) && (array->shape(dim) < low_)) { 24 | return Sample(); 25 | } 26 | if ((high_ >= 0) && (array->shape(dim) > high_)) { 27 | return Sample(); 28 | } 29 | return sample; 30 | } 31 | } // namespace op 32 | } // namespace data 33 | } // namespace mlx 34 | -------------------------------------------------------------------------------- /mlx/data/op/FilterByShape.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class FilterByShape : public Op { 12 | public: 13 | FilterByShape( 14 | const std::string& key, 15 | int dim, 16 | int64_t low = -1, 17 | int64_t high = -1); 18 | 19 | virtual Sample apply(const Sample& sample) const override; 20 | 21 | private: 22 | std::string key_; 23 | int dim_; 24 | int64_t low_; 25 | int64_t high_; 26 | }; 27 | 28 | } // namespace op 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/op/FilterKey.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/FilterKey.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | FilterKey::FilterKey(const std::string& key, bool remove) 9 | : keys_({key}), remove_(remove) {} 10 | FilterKey::FilterKey(const std::vector& keys, bool remove) 11 | : keys_(keys), remove_(remove) {} 12 | Sample FilterKey::apply(const Sample& sample) const { 13 | Sample res; 14 | if (remove_) { 15 | res = sample; 16 | for (auto& key : keys_) { 17 | sample::check_key(sample, key, ArrayType::Any); 18 | res.erase(key); 19 | } 20 | } else { 21 | for (auto& key : keys_) { 22 | auto array = sample::check_key(sample, key, ArrayType::Any); 23 | res[key] = array; 24 | } 25 | } 26 | return res; 27 | } 28 | } // namespace op 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/op/FilterKey.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class FilterKey : public Op { 12 | public: 13 | FilterKey(const std::string& key, bool remove = false); 14 | FilterKey(const std::vector& keys, bool remove = false); 15 | 16 | virtual Sample apply(const Sample& sample) const override; 17 | 18 | private: 19 | std::vector keys_; 20 | bool remove_; 21 | }; 22 | 23 | } // namespace op 24 | } // namespace data 25 | } // namespace mlx 26 | -------------------------------------------------------------------------------- /mlx/data/op/KeyTransform.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | KeyTransformOp::KeyTransformOp(const std::string& ikey, const std::string& okey) 12 | : ikey_(ikey), okey_(okey) {}; 13 | 14 | Sample KeyTransformOp::apply(const Sample& sample) const { 15 | auto src = sample::check_key(sample, ikey_, ArrayType::Any); 16 | auto dst = apply_key(src); 17 | auto res = sample; 18 | auto okey = (okey_.empty() ? ikey_ : okey_); 19 | res[okey] = dst; 20 | return res; 21 | } 22 | 23 | KeyTransform::KeyTransform( 24 | const std::string& ikey, 25 | std::function(const std::shared_ptr&)> 26 | op, 27 | const std::string& okey) 28 | : KeyTransformOp(ikey, okey), op_(op) {}; 29 | 30 | std::shared_ptr KeyTransform::apply_key( 31 | const std::shared_ptr& x) const { 32 | return op_(x); 33 | } 34 | 35 | } // namespace op 36 | } // namespace data 37 | } // namespace mlx 38 | -------------------------------------------------------------------------------- /mlx/data/op/KeyTransform.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/data/op/Op.h" 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace op { 12 | 13 | class KeyTransformOp : public Op { 14 | public: 15 | KeyTransformOp(const std::string& ikey, const std::string& okey = ""); 16 | 17 | virtual Sample apply(const Sample& sample) const; 18 | virtual std::shared_ptr apply_key( 19 | const std::shared_ptr& x) const = 0; 20 | 21 | protected: 22 | std::string ikey_; 23 | std::string okey_; 24 | }; 25 | 26 | class KeyTransform : public KeyTransformOp { 27 | public: 28 | KeyTransform( 29 | const std::string& ikey, 30 | std::function(const std::shared_ptr&)> 31 | op, 32 | const std::string& okey = ""); 33 | 34 | virtual std::shared_ptr apply_key( 35 | const std::shared_ptr& x) const override; 36 | 37 | private: 38 | std::function(const std::shared_ptr&)> 39 | op_; 40 | }; 41 | 42 | } // namespace op 43 | } // namespace data 44 | } // namespace mlx 45 | -------------------------------------------------------------------------------- /mlx/data/op/LoadAudio.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | enum class LoadAudioInfo { 12 | All, 13 | NumFrames, 14 | NumChannels, 15 | SampleRate, 16 | NumSeconds 17 | }; 18 | 19 | enum class LoadAudioResamplingQuality { 20 | SincBest, 21 | SincMedium, 22 | SincFastest, 23 | ZeroOrderHold, 24 | Linear, 25 | }; 26 | 27 | class LoadAudio : public Op { 28 | public: 29 | LoadAudio( 30 | const std::string& ikey, 31 | const std::string& prefix = "", 32 | // if info is true, provides info 33 | // either in infokey (if provided) 34 | // or in okey (in which case no audio will be loaded) 35 | bool info = false, 36 | bool from_memory = false, 37 | LoadAudioInfo info_type = LoadAudioInfo::All, 38 | int sample_rate = 0, 39 | LoadAudioResamplingQuality resampling_quality = 40 | LoadAudioResamplingQuality::SincFastest, 41 | const std::string& infokey = "", 42 | const std::string& okey = ""); 43 | 44 | virtual Sample apply(const Sample& sample) const override; 45 | 46 | private: 47 | std::string iKey_; 48 | std::string oKey_; 49 | std::string infoKey_; 50 | std::string prefix_; 51 | bool info_; 52 | bool from_memory_; 53 | LoadAudioInfo infoType_; 54 | int sampleRate_; 55 | LoadAudioResamplingQuality resamplingQuality_; 56 | }; 57 | 58 | class ResampleAudio : public Op { 59 | public: 60 | // infokey: metadata info provided by LoadAudio, 61 | // which can be a scalar (assumed to be the sample rate itself) 62 | // or a vector containing {audio_length, audio_channels, audio_sample_rate} 63 | // instead of infokey, explicit input_sample_rate can also be provided 64 | ResampleAudio( 65 | const std::string& ikey, 66 | int output_sample_rate, 67 | int input_sample_rate = 0, 68 | const std::string& infokey = "", 69 | LoadAudioResamplingQuality resampling_quality = 70 | LoadAudioResamplingQuality::SincFastest, 71 | const std::string& okey = ""); 72 | 73 | virtual Sample apply(const Sample& sample) const override; 74 | 
75 | private: 76 | std::string iKey_; 77 | std::string oKey_; 78 | std::string infoKey_; 79 | LoadAudioResamplingQuality resamplingQuality_; 80 | int inputSampleRate_; 81 | int outputSampleRate_; 82 | }; 83 | 84 | } // namespace op 85 | } // namespace data 86 | } // namespace mlx 87 | -------------------------------------------------------------------------------- /mlx/data/op/LoadFile.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/op/LoadFile.h" 4 | 5 | #include 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | LoadFile::LoadFile( 11 | const std::string& ikey, 12 | const std::filesystem::path& prefix, 13 | const std::string& okey) 14 | : KeyTransformOp(ikey, okey), prefix_(prefix) {} 15 | 16 | std::shared_ptr LoadFile::apply_key( 17 | const std::shared_ptr& src) const { 18 | if (src->type() != ArrayType::Int8) { 19 | throw std::runtime_error("LoadFile: char array (int8) expected"); 20 | } 21 | std::filesystem::path path = prefix_; 22 | std::string filename(reinterpret_cast(src->data()), src->size()); 23 | path /= filename; 24 | 25 | std::shared_ptr dst; 26 | std::ifstream file; 27 | file.exceptions(std::ifstream::badbit); 28 | try { 29 | file.open(path, std::ios::binary | std::ios::ate); 30 | int64_t file_size = file.tellg(); 31 | file.seekg(0, std::ios::beg); 32 | dst = std::make_shared(Int8, file_size); 33 | file.read(dst->data(), file_size); 34 | file.close(); 35 | } catch (const std::ifstream::failure& e) { 36 | throw std::runtime_error( 37 | std::string("LoadFile: unable to read ") + path.string()); 38 | } 39 | return dst; 40 | } 41 | } // namespace op 42 | } // namespace data 43 | } // namespace mlx 44 | -------------------------------------------------------------------------------- /mlx/data/op/LoadFile.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/op/KeyTransform.h" 9 | 10 | namespace mlx { 11 | namespace data { 12 | namespace op { 13 | 14 | /// Operation that will load a file into memory -- similar to `OpReadFromTAR` 15 | /// but loads directly from the filesystem. This is useful for testing the 16 | /// in-memory path. 17 | class LoadFile : public KeyTransformOp { 18 | public: 19 | LoadFile( 20 | const std::string& ikey, 21 | const std::filesystem::path& prefix = "", 22 | const std::string& okey = ""); 23 | 24 | virtual std::shared_ptr apply_key( 25 | const std::shared_ptr& src) const override; 26 | 27 | private: 28 | std::filesystem::path prefix_; 29 | }; 30 | 31 | } // namespace op 32 | } // namespace data 33 | } // namespace mlx 34 | -------------------------------------------------------------------------------- /mlx/data/op/LoadImage.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/core/image/Image.h" 6 | #include "mlx/data/op/LoadImage.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace op { 11 | LoadImage::LoadImage( 12 | const std::string& ikey, 13 | const std::string& prefix, 14 | bool info, 15 | const std::string& format, 16 | bool from_memory, 17 | const std::string& okey) 18 | : KeyTransformOp(ikey, okey), 19 | prefix_(prefix), 20 | info_(info), 21 | format_(format), 22 | from_memory_(from_memory) {} 23 | std::shared_ptr LoadImage::apply_key( 24 | const std::shared_ptr& src) const { 25 | std::filesystem::path path; 26 | std::shared_ptr dst; 27 | if (!from_memory_) { 28 | path = prefix_; 29 | if (src->type() != ArrayType::Int8) { 30 | throw std::runtime_error("LoadImage: char array (int8) expected"); 31 | } 32 | std::string filename(reinterpret_cast(src->data()), src->size()); 33 | path /= filename; 34 | } 35 | if (info_) { 36 | auto info = from_memory_ ? 
core::image::info(src) : core::image::info(path); 37 | std::vector info_array({info.width, info.height}); 38 | dst = std::make_shared(info_array); 39 | } else { 40 | dst = from_memory_ ? core::image::load(src) : core::image::load(path); 41 | if (!dst) { 42 | throw std::runtime_error( 43 | "LoadImage: unable to load image <" + 44 | (from_memory_ ? "stream" : path.string()) + ">"); 45 | } 46 | } 47 | return dst; 48 | } 49 | } // namespace op 50 | } // namespace data 51 | } // namespace mlx 52 | -------------------------------------------------------------------------------- /mlx/data/op/LoadImage.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class LoadImage : public KeyTransformOp { 12 | public: 13 | // note: info=true is meant to fast-retrieval of image size and thus will not 14 | // perform transformations 15 | LoadImage( 16 | const std::string& ikey, 17 | const std::string& prefix = "", 18 | bool info = false, 19 | const std::string& format = "RGB", 20 | bool from_memory = false, 21 | const std::string& okey = ""); 22 | 23 | virtual std::shared_ptr apply_key( 24 | const std::shared_ptr& src) const override; 25 | 26 | private: 27 | std::string prefix_; 28 | bool info_; 29 | std::string format_; 30 | bool from_memory_; 31 | }; 32 | 33 | } // namespace op 34 | } // namespace data 35 | } // namespace mlx 36 | -------------------------------------------------------------------------------- /mlx/data/op/LoadNumpy.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/LoadNumpy.h" 4 | #include "mlx/data/core/Numpy.h" 5 | #include "mlx/data/core/imemstream.h" 6 | 7 | #include 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace op { 12 | LoadNumpy::LoadNumpy( 13 | const std::string& ikey, 14 | const std::string& prefix, 15 | bool from_memory, 16 | const std::string& okey) 17 | : KeyTransformOp(ikey, okey), prefix_(prefix), from_memory_(from_memory) {} 18 | 19 | std::shared_ptr LoadNumpy::apply_key( 20 | const std::shared_ptr& src) const { 21 | std::shared_ptr dst; 22 | if (from_memory_) { 23 | auto stream = core::imemstream(src); 24 | dst = core::load_numpy(stream, ""); 25 | } else { 26 | std::filesystem::path path = prefix_; 27 | if (src->type() != ArrayType::Int8) { 28 | throw std::runtime_error("LoadNumpy: char array (int8) expected"); 29 | } 30 | std::string filename(reinterpret_cast(src->data()), src->size()); 31 | path /= filename; 32 | dst = core::load_numpy(path); 33 | } 34 | return dst; 35 | } 36 | } // namespace op 37 | } // namespace data 38 | } // namespace mlx 39 | -------------------------------------------------------------------------------- /mlx/data/op/LoadNumpy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class LoadNumpy : public KeyTransformOp { 12 | public: 13 | LoadNumpy( 14 | const std::string& ikey, 15 | const std::string& prefix = "", 16 | bool from_memory = false, 17 | const std::string& okey = ""); 18 | 19 | virtual std::shared_ptr apply_key( 20 | const std::shared_ptr& src) const override; 21 | 22 | private: 23 | std::string prefix_; 24 | bool from_memory_; 25 | }; 26 | 27 | } // namespace op 28 | } // namespace data 29 | } // namespace mlx 30 | -------------------------------------------------------------------------------- /mlx/data/op/LoadVideo.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/core/video/Video.h" 6 | #include "mlx/data/op/LoadVideo.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace op { 11 | LoadVideo::LoadVideo( 12 | const std::string& ikey, 13 | const std::string& prefix, 14 | bool info, 15 | bool from_memory, 16 | const std::string& okey) 17 | : KeyTransformOp(ikey, okey), 18 | prefix_(prefix), 19 | info_(info), 20 | from_memory_(from_memory) {} 21 | std::shared_ptr LoadVideo::apply_key( 22 | const std::shared_ptr& src) const { 23 | std::filesystem::path path; 24 | 25 | if (!from_memory_) { 26 | path = prefix_; 27 | if (src->type() != ArrayType::Int8) { 28 | throw std::runtime_error("LoadImage: char array (int8) expected"); 29 | } 30 | std::string filename(reinterpret_cast(src->data()), src->size()); 31 | path /= filename; 32 | } 33 | 34 | std::shared_ptr dst; 35 | if (info_) { 36 | auto info = from_memory_ ? core::video::info(src) : core::video::info(path); 37 | 38 | std::vector shape({info.width, info.height, info.frames}); 39 | dst = std::make_shared(shape); 40 | 41 | } else { 42 | dst = from_memory_ ? 
core::video::load(src) : core::video::load(path); 43 | if (!dst) { 44 | throw std::runtime_error( 45 | "LoadVideo: unable to load video <" + 46 | (from_memory_ ? "stream" : path.string()) + ">"); 47 | } 48 | } 49 | 50 | return dst; 51 | } 52 | } // namespace op 53 | } // namespace data 54 | } // namespace mlx 55 | -------------------------------------------------------------------------------- /mlx/data/op/LoadVideo.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class LoadVideo : public KeyTransformOp { 12 | public: 13 | LoadVideo( 14 | const std::string& ikey, 15 | const std::string& prefix = "", 16 | bool info = false, 17 | bool from_memory = false, 18 | const std::string& okey = ""); 19 | 20 | virtual std::shared_ptr apply_key( 21 | const std::shared_ptr& src) const override; 22 | 23 | private: 24 | std::string prefix_; 25 | bool info_; 26 | bool from_memory_; 27 | }; 28 | 29 | } // namespace op 30 | } // namespace data 31 | } // namespace mlx 32 | -------------------------------------------------------------------------------- /mlx/data/op/Op.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | Sample Op::apply(const Sample& sample) const { 12 | throw std::runtime_error("Op::apply() NYI"); 13 | } 14 | 15 | Op::~Op() {} 16 | 17 | } // namespace op 18 | } // namespace data 19 | } // namespace mlx 20 | -------------------------------------------------------------------------------- /mlx/data/op/Op.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/data/Array.h" 10 | #include "mlx/data/Sample.h" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace op { 15 | 16 | class Op { 17 | public: 18 | Op() {}; 19 | 20 | // DEBUG: (debatable) sample could be not const 21 | virtual Sample apply(const Sample& sample) const; 22 | 23 | virtual ~Op(); 24 | }; 25 | 26 | } // namespace op 27 | } // namespace data 28 | } // namespace mlx 29 | -------------------------------------------------------------------------------- /mlx/data/op/Pad.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/op/Pad.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | 9 | Pad::Pad( 10 | const std::string& ikey, 11 | int dim, 12 | int64_t lpad, 13 | int64_t rpad, 14 | double value, 15 | const std::string& okey) 16 | : KeyTransformOp(ikey, okey), 17 | dim_(dim), 18 | lpad_(lpad), 19 | rpad_(rpad), 20 | value_(value) { 21 | if (lpad_ < 0 || rpad_ < 0) { 22 | throw std::runtime_error("Pad: pad value must be positive"); 23 | } 24 | } 25 | std::shared_ptr Pad::apply_key( 26 | const std::shared_ptr& src) const { 27 | auto dim = src->checkdim(dim_); 28 | return array::pad(src, dim, lpad_, rpad_, value_); 29 | } 30 | 31 | PadToSize::PadToSize( 32 | const std::string& ikey, 33 | int dim, 34 | int64_t size, 35 | double value, 36 | const std::string& okey) 37 | : KeyTransformOp(ikey, okey), dim_(dim), sizes_({size}), value_(value) {} 38 | PadToSize::PadToSize( 39 | const std::string& ikey, 40 | int dim, 41 | const std::vector& sizes, 42 | double value, 43 | const std::string& okey) 44 | : KeyTransformOp(ikey, okey), dim_(dim), sizes_(sizes), value_(value) {} 45 | std::shared_ptr PadToSize::apply_key( 46 | const std::shared_ptr& src) const { 47 | auto dim = src->checkdim(dim_); 48 | int64_t min_diff_idx = -1; 49 | int64_t min_diff_size = 
std::numeric_limits::max(); 50 | int64_t dim_size = src->shape(dim); 51 | for (int i = 0; i < sizes_.size(); i++) { 52 | auto diff_size = sizes_[i] - dim_size; 53 | if (diff_size > 0 && diff_size < min_diff_size) { 54 | min_diff_size = diff_size; 55 | min_diff_idx = i; 56 | } 57 | } 58 | if (min_diff_idx >= 0) { 59 | return array::pad(src, dim, 0, min_diff_size, value_); 60 | } else { 61 | return mlx::data::array::clone(src); 62 | } 63 | } 64 | 65 | PadToMultiple::PadToMultiple( 66 | const std::string& ikey, 67 | int dim, 68 | int64_t size, 69 | double value, 70 | const std::string& okey) 71 | : KeyTransformOp(ikey, okey), dim_(dim), size_(size), value_(value) {} 72 | std::shared_ptr PadToMultiple::apply_key( 73 | const std::shared_ptr& src) const { 74 | auto dim = src->checkdim(dim_); 75 | int64_t mod = src->shape(dim) % size_; 76 | if (mod != 0) { 77 | return array::pad(src, dim, 0, size_ - mod, value_); 78 | } else { 79 | return mlx::data::array::clone(src); 80 | } 81 | } 82 | 83 | } // namespace op 84 | } // namespace data 85 | } // namespace mlx 86 | -------------------------------------------------------------------------------- /mlx/data/op/Pad.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class Pad : public KeyTransformOp { 12 | public: 13 | Pad(const std::string& ikey, 14 | int dim, 15 | int64_t lpad, 16 | int64_t rpad, 17 | double value, 18 | const std::string& okey = ""); 19 | 20 | virtual std::shared_ptr apply_key( 21 | const std::shared_ptr& src) const override; 22 | 23 | private: 24 | int dim_; 25 | int64_t lpad_; 26 | int64_t rpad_; 27 | double value_; 28 | }; 29 | 30 | class PadToSize : public KeyTransformOp { 31 | public: 32 | PadToSize( 33 | const std::string& ikey, 34 | int dim, 35 | int64_t size, 36 | double value, 37 | const std::string& okey = ""); 38 | PadToSize( 39 | const std::string& ikey, 40 | int dim, 41 | const std::vector& sizes, 42 | double value, 43 | const std::string& okey = ""); 44 | 45 | virtual std::shared_ptr apply_key( 46 | const std::shared_ptr& src) const override; 47 | 48 | private: 49 | int dim_; 50 | std::vector sizes_; 51 | double value_; 52 | }; 53 | 54 | class PadToMultiple : public KeyTransformOp { 55 | public: 56 | PadToMultiple( 57 | const std::string& ikey, 58 | int dim, 59 | int64_t size, 60 | double value, 61 | const std::string& okey = ""); 62 | 63 | virtual std::shared_ptr apply_key( 64 | const std::shared_ptr& src) const override; 65 | 66 | private: 67 | int dim_; 68 | int64_t size_; 69 | double value_; 70 | }; 71 | 72 | } // namespace op 73 | } // namespace data 74 | } // namespace mlx 75 | -------------------------------------------------------------------------------- /mlx/data/op/ReadFromTAR.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/ReadFromTAR.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | ReadFromTAR::ReadFromTAR( 9 | const std::string& tarkey, 10 | const std::string& ikey, 11 | const std::string& okey, 12 | const std::filesystem::path& prefix, 13 | const std::filesystem::path& tar_prefix, 14 | bool from_key, 15 | std::shared_ptr fetcher, 16 | bool nested, 17 | int num_threads) 18 | : tarkey_(tarkey), 19 | ikey_(ikey), 20 | okey_(okey), 21 | prefix_(prefix), 22 | tarPrefix_(tar_prefix), 23 | fromKey_(from_key), 24 | fetcher_(fetcher), 25 | nested_(nested), 26 | numThreads_(num_threads) { 27 | if (!from_key) { 28 | // load the tar index 29 | get_tar_reader_(tarkey); 30 | } 31 | } 32 | std::pair< 33 | std::shared_ptr, 34 | std::shared_ptr> 35 | ReadFromTAR::get_tar_reader_(const std::string& key) const { 36 | // make sure tar file is actually on disk 37 | std::shared_ptr handle; 38 | if (fetcher_) { 39 | handle = fetcher_->fetch(key); 40 | } 41 | { 42 | std::shared_lock slock(mutex_); 43 | auto it = tars_.find(key); 44 | if (it != tars_.end()) { 45 | return std::make_pair(it->second, handle); 46 | } 47 | } 48 | { 49 | std::unique_lock ulock(mutex_); 50 | auto key_path = tarPrefix_ / key; 51 | auto tar = std::make_shared( 52 | key_path.string(), nested_, numThreads_); 53 | tars_[key] = tar; 54 | return std::make_pair(tar, handle); 55 | } 56 | } 57 | Sample ReadFromTAR::apply(const Sample& sample) const { 58 | std::string tarfilename; 59 | if (fromKey_) { 60 | auto tarfilename_array = 61 | sample::check_key(sample, tarkey_, ArrayType::Int8); 62 | tarfilename = std::string( 63 | reinterpret_cast(tarfilename_array->data()), 64 | tarfilename_array->size()); 65 | } else { 66 | tarfilename = tarkey_; 67 | } 68 | auto tar = get_tar_reader_(tarfilename); 69 | std::shared_ptr input_array; 70 | input_array = sample::check_key(sample, ikey_, ArrayType::Int8); 71 | std::string filename( 72 | reinterpret_cast(input_array->data()), input_array->size()); 
73 | auto filepath = prefix_ / filename; 74 | auto output_array = tar.first->get(filepath.string()); 75 | auto res = sample; 76 | res[okey_] = output_array; 77 | return res; 78 | } 79 | } // namespace op 80 | } // namespace data 81 | } // namespace mlx 82 | -------------------------------------------------------------------------------- /mlx/data/op/ReadFromTAR.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "mlx/data/core/FileFetcher.h" 12 | #include "mlx/data/core/TARReader.h" 13 | #include "mlx/data/op/Op.h" 14 | 15 | namespace mlx { 16 | namespace data { 17 | namespace op { 18 | 19 | class ReadFromTAR : public Op { 20 | public: 21 | ReadFromTAR( 22 | const std::string& tarkey, 23 | const std::string& ikey, 24 | const std::string& okey, 25 | const std::filesystem::path& prefix = "", 26 | const std::filesystem::path& tar_prefix = "", 27 | bool from_key = false, 28 | std::shared_ptr fetcher = nullptr, 29 | bool nested = false, 30 | int num_threads = 1); 31 | 32 | virtual Sample apply(const Sample& sample) const override; 33 | 34 | private: 35 | std::pair< 36 | std::shared_ptr, 37 | std::shared_ptr> 38 | get_tar_reader_(const std::string& key) const; 39 | 40 | std::string tarkey_; 41 | std::string ikey_; 42 | std::string okey_; 43 | std::filesystem::path prefix_; 44 | std::filesystem::path tarPrefix_; 45 | bool fromKey_; 46 | std::shared_ptr fetcher_; 47 | bool nested_; 48 | int numThreads_; 49 | mutable std::unordered_map> 50 | tars_; 51 | mutable std::shared_mutex mutex_; 52 | }; 53 | 54 | } // namespace op 55 | } // namespace data 56 | } // namespace mlx 57 | -------------------------------------------------------------------------------- /mlx/data/op/RemoveValue.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/RemoveValue.h" 4 | #include "mlx/data/core/Utils.h" 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace op { 9 | RemoveValue::RemoveValue( 10 | const std::string& key, 11 | const std::string& size_key, 12 | int dim, 13 | double value, 14 | double pad) 15 | : key_(key), size_key_(size_key), dim_(dim), value_(value), pad_(pad) {} 16 | 17 | Sample RemoveValue::apply(const Sample& sample) const { 18 | auto array = sample::check_key(sample, key_, ArrayType::Any); 19 | auto size_array = sample::check_key(sample, size_key_, ArrayType::Int64); 20 | std::tie(array, size_array) = 21 | core::remove(array, size_array, dim_, value_, pad_); 22 | auto new_sample = sample; 23 | new_sample[key_] = array; 24 | new_sample[size_key_] = size_array; 25 | return new_sample; 26 | } 27 | } // namespace op 28 | } // namespace data 29 | } // namespace mlx 30 | -------------------------------------------------------------------------------- /mlx/data/op/RemoveValue.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class RemoveValue : public Op { 12 | public: 13 | RemoveValue( 14 | const std::string& key, 15 | const std::string& size_key, 16 | int dim, 17 | double value, 18 | double pad); 19 | 20 | virtual Sample apply(const Sample& sample) const override; 21 | 22 | private: 23 | std::string key_; 24 | std::string size_key_; 25 | int dim_; 26 | double value_; 27 | double pad_; 28 | }; 29 | 30 | } // namespace op 31 | } // namespace data 32 | } // namespace mlx 33 | -------------------------------------------------------------------------------- /mlx/data/op/RenameKey.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/RenameKey.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | RenameKey::RenameKey(const std::string& ikey, const std::string& okey) 9 | : ikey_(ikey), okey_(okey) {} 10 | Sample RenameKey::apply(const Sample& sample) const { 11 | auto input_array = sample::check_key(sample, ikey_, ArrayType::Any); 12 | if (ikey_ == okey_) { 13 | return sample; 14 | } 15 | auto res = sample; 16 | res[okey_] = input_array; 17 | res.erase(ikey_); 18 | return res; 19 | } 20 | } // namespace op 21 | } // namespace data 22 | } // namespace mlx 23 | -------------------------------------------------------------------------------- /mlx/data/op/RenameKey.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class RenameKey : public Op { 12 | public: 13 | RenameKey(const std::string& ikey, const std::string& okey); 14 | 15 | virtual Sample apply(const Sample& sample) const override; 16 | 17 | private: 18 | std::string ikey_; 19 | std::string okey_; 20 | }; 21 | 22 | } // namespace op 23 | } // namespace data 24 | } // namespace mlx 25 | -------------------------------------------------------------------------------- /mlx/data/op/Replace.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include "mlx/data/op/Replace.h" 4 | #include "mlx/data/core/Utils.h" 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace op { 9 | 10 | Replace::Replace( 11 | const std::string& key, 12 | const std::string& old, 13 | const std::string& replacement, 14 | int count) 15 | : KeyTransformOp(key), 16 | old_(std::make_shared(old)), 17 | replacement_(std::make_shared(replacement)), 18 | count_(count) {} 19 | 20 | std::shared_ptr Replace::apply_key( 21 | const std::shared_ptr& src) const { 22 | return core::replace(src, old_, replacement_, count_); 23 | } 24 | 25 | ReplaceBytes::ReplaceBytes( 26 | const std::string& ikey, 27 | std::vector byte_map, 28 | const std::string& okey) 29 | : KeyTransformOp(ikey, okey), byte_map_(std::move(byte_map)) { 30 | while (byte_map_.size() < 256) { 31 | byte_map_.emplace_back(""); 32 | } 33 | } 34 | 35 | std::shared_ptr ReplaceBytes::apply_key( 36 | const std::shared_ptr& src) const { 37 | std::string result; 38 | // waste some space but ensure that we most often we do only 2 allocations 39 | result.reserve(2 * src->size() * src->itemsize()); 40 | 41 | void* raw_data = src->data(); 42 | uint8_t* byte_data = reinterpret_cast(raw_data); 43 | for (int64_t i = 0; i < src->size() * src->itemsize(); i++) { 44 | result += byte_map_[*byte_data]; 45 | byte_data++; 46 | } 47 | 48 | return std::make_shared(result); 49 | } 50 | 51 | } // namespace op 52 | } // namespace data 53 | } // namespace mlx 54 | -------------------------------------------------------------------------------- /mlx/data/op/Replace.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/data/op/KeyTransform.h" 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace op { 12 | 13 | class Replace : public KeyTransformOp { 14 | public: 15 | Replace( 16 | const std::string& key, 17 | const std::string& old, 18 | const std::string& replacement, 19 | int count); 20 | 21 | virtual std::shared_ptr apply_key( 22 | const std::shared_ptr& src) const override; 23 | 24 | private: 25 | std::string key_; 26 | std::shared_ptr old_; 27 | std::shared_ptr replacement_; 28 | int count_; 29 | }; 30 | 31 | class ReplaceBytes : public KeyTransformOp { 32 | public: 33 | ReplaceBytes( 34 | const std::string& ikey, 35 | std::vector byte_map, 36 | const std::string& okey = ""); 37 | 38 | virtual std::shared_ptr apply_key( 39 | const std::shared_ptr& src) const override; 40 | 41 | private: 42 | std::vector byte_map_; 43 | }; 44 | 45 | } // namespace op 46 | } // namespace data 47 | } // namespace mlx 48 | -------------------------------------------------------------------------------- /mlx/data/op/SampleTransform.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/op/SampleTransform.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | 9 | SampleTransform::SampleTransform(std::function op) 10 | : op_(op) {} 11 | 12 | Sample SampleTransform::apply(const Sample& sample) const { 13 | return op_(sample); 14 | } 15 | 16 | } // namespace op 17 | } // namespace data 18 | } // namespace mlx 19 | -------------------------------------------------------------------------------- /mlx/data/op/SampleTransform.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/data/op/Op.h" 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace op { 12 | 13 | class SampleTransform : public Op { 14 | public: 15 | SampleTransform(std::function op); 16 | 17 | virtual Sample apply(const Sample& sample) const; 18 | 19 | private: 20 | std::function op_; 21 | }; 22 | 23 | } // namespace op 24 | } // namespace data 25 | } // namespace mlx 26 | -------------------------------------------------------------------------------- /mlx/data/op/SaveImage.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/data/core/image/Image.h" 7 | #include "mlx/data/op/SaveImage.h" 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace op { 12 | SaveImage::SaveImage( 13 | const std::string& image_key, 14 | const std::string& filename_key, 15 | const std::string& prefix, 16 | const std::string& filename_prefix) 17 | : imageKey_(image_key), 18 | filenameKey_(filename_key), 19 | prefix_(prefix), 20 | filenamePrefix_(filename_prefix) {} 21 | Sample SaveImage::apply(const Sample& sample) const { 22 | std::shared_ptr input_array; 23 | input_array = sample::check_key(sample, imageKey_, ArrayType::UInt8); 24 | 25 | std::filesystem::path path = prefix_; 26 | 27 | std::shared_ptr base_filename_array; 28 | base_filename_array = 29 | sample::check_key(sample, filenameKey_, ArrayType::Int8); 30 | std::string base_filename( 31 | reinterpret_cast(base_filename_array->data()), 32 | base_filename_array->size()); 33 | 34 | if (filenamePrefix_.length() > 0) { 35 | path /= filenamePrefix_ + base_filename; 36 | } else { 37 | path /= base_filename; 38 | } 39 | 40 | auto shape = input_array->shape(); 41 | if (shape.size() == 4) { 42 | // a vector of images (e.g. 
a video) 43 | for (int i = 0; i < shape[0]; i++) { 44 | auto frame = array::slice(input_array, i); 45 | 46 | std::stringstream ext; 47 | ext << std::setw(6) << std::setfill('0') << i << ".jpg"; 48 | 49 | auto frame_path = path; 50 | frame_path.replace_extension(ext.str()); 51 | 52 | if (!core::image::save(frame, frame_path)) { 53 | throw std::runtime_error( 54 | "SaveImage: unable to save frame " + frame_path.string()); 55 | } 56 | } 57 | } else { 58 | // simple image: HxWxC 59 | path.replace_extension("jpg"); 60 | if (!core::image::save(input_array, path)) { 61 | throw std::runtime_error( 62 | "SaveImage: no provider to save image " + path.string()); 63 | } 64 | } 65 | 66 | return sample; 67 | } 68 | } // namespace op 69 | } // namespace data 70 | } // namespace mlx 71 | -------------------------------------------------------------------------------- /mlx/data/op/SaveImage.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class SaveImage : public Op { 12 | public: 13 | SaveImage( 14 | const std::string& image_key, 15 | const std::string& filename_key, 16 | const std::string& prefix = "", 17 | const std::string& filename_prefix = ""); 18 | 19 | virtual Sample apply(const Sample& sample) const override; 20 | 21 | private: 22 | std::string imageKey_; 23 | std::string filenameKey_; 24 | std::string prefix_; 25 | std::string filenamePrefix_; 26 | }; 27 | 28 | } // namespace op 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/op/Shape.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/Shape.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | Shape::Shape(const std::string& ikey, int dim, const std::string& okey) 9 | : ikey_(ikey), dim_(dim), okey_(okey), fullShape_(false) {} 10 | Shape::Shape(const std::string& ikey, const std::string& okey) 11 | : ikey_(ikey), okey_(okey), fullShape_(true) {} 12 | Sample Shape::apply(const Sample& sample) const { 13 | auto input_array = sample::check_key(sample, ikey_, ArrayType::Any); 14 | std::shared_ptr output_array; 15 | if (fullShape_) { 16 | output_array = std::make_shared(input_array->shape()); 17 | } else { 18 | auto dim = input_array->checkdim(dim_); 19 | output_array = std::make_shared(input_array->shape(dim)); 20 | } 21 | auto res = sample; 22 | res[okey_] = output_array; 23 | return res; 24 | } 25 | } // namespace op 26 | } // namespace data 27 | } // namespace mlx 28 | -------------------------------------------------------------------------------- /mlx/data/op/Shape.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/Op.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class Shape : public Op { 12 | public: 13 | Shape(const std::string& ikey, int dim, const std::string& okey); 14 | Shape(const std::string& ikey, const std::string& okey); 15 | 16 | virtual Sample apply(const Sample& sample) const override; 17 | 18 | private: 19 | std::string ikey_; 20 | int dim_; 21 | std::string okey_; 22 | bool fullShape_; 23 | }; 24 | 25 | } // namespace op 26 | } // namespace data 27 | } // namespace mlx 28 | -------------------------------------------------------------------------------- /mlx/data/op/Shard.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/Shard.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | Shard::Shard(const std::string& ikey, int64_t n_shards, const std::string& okey) 9 | : KeyTransformOp(ikey, okey), nShards_(n_shards) {} 10 | 11 | std::shared_ptr Shard::apply_key( 12 | const std::shared_ptr& src) const { 13 | std::vector shape = src->shape(); 14 | if (shape.size() > 0) { 15 | shape[0] = -1; 16 | shape.insert(shape.begin(), nShards_); 17 | return mlx::data::array::reshape(src, shape); 18 | } else { 19 | return mlx::data::array::clone(src); 20 | } 21 | } 22 | } // namespace op 23 | } // namespace data 24 | } // namespace mlx 25 | -------------------------------------------------------------------------------- /mlx/data/op/Shard.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class Shard : public KeyTransformOp { 12 | public: 13 | Shard( 14 | const std::string& ikey, 15 | int64_t n_shards, 16 | const std::string& okey = ""); 17 | 18 | virtual std::shared_ptr apply_key( 19 | const std::shared_ptr& src) const override; 20 | 21 | private: 22 | int64_t nShards_; 23 | }; 24 | 25 | } // namespace op 26 | } // namespace data 27 | } // namespace mlx 28 | -------------------------------------------------------------------------------- /mlx/data/op/Slice.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class Slice : public KeyTransformOp { 12 | public: 13 | Slice( 14 | const std::string& ikey, 15 | int dim, 16 | int64_t start, 17 | int64_t end, 18 | const std::string& okey = ""); 19 | Slice( 20 | const std::string& ikey, 21 | std::vector dims, 22 | std::vector starts, 23 | std::vector ends, 24 | const std::string& okey = ""); 25 | 26 | virtual std::shared_ptr apply_key( 27 | const std::shared_ptr& src) const override; 28 | 29 | private: 30 | std::vector dims_; 31 | std::vector starts_; 32 | std::vector ends_; 33 | }; 34 | 35 | class RandomSlice : public KeyTransformOp { 36 | public: 37 | RandomSlice( 38 | const std::string& ikey, 39 | int dim, 40 | int64_t size, 41 | const std::string& okey = ""); 42 | RandomSlice( 43 | const std::string& ikey, 44 | std::vector dims, 45 | std::vector sizes, 46 | const std::string& okey = ""); 47 | 48 | virtual std::shared_ptr apply_key( 49 | const std::shared_ptr& src) const override; 50 | 51 | private: 52 | std::vector dims_; 53 | std::vector sizes_; 54 | }; 55 | 56 | } // namespace op 57 | } // namespace data 58 | } // namespace mlx 59 | -------------------------------------------------------------------------------- /mlx/data/op/Squeeze.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/Squeeze.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | Squeeze::Squeeze(const std::string& ikey, const std::string& okey) 9 | : KeyTransformOp(ikey, okey) {} 10 | Squeeze::Squeeze(const std::string& ikey, int dim, const std::string& okey) 11 | : KeyTransformOp(ikey, okey), dims_({dim}) {} 12 | Squeeze::Squeeze( 13 | const std::string& ikey, 14 | const std::vector& dims, 15 | const std::string& okey) 16 | : KeyTransformOp(ikey, okey), dims_(dims) {} 17 | 18 | std::shared_ptr Squeeze::apply_key( 19 | const std::shared_ptr& src) const { 20 | return mlx::data::array::squeeze(src, dims_); 21 | } 22 | } // namespace op 23 | } // namespace data 24 | } // namespace mlx 25 | -------------------------------------------------------------------------------- /mlx/data/op/Squeeze.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/op/KeyTransform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace op { 10 | 11 | class Squeeze : public KeyTransformOp { 12 | public: 13 | Squeeze(const std::string& ikey, const std::string& okey = ""); 14 | Squeeze(const std::string& ikey, int dim, const std::string& okey = ""); 15 | Squeeze( 16 | const std::string& ikey, 17 | const std::vector& dims, 18 | const std::string& okey = ""); 19 | 20 | virtual std::shared_ptr apply_key( 21 | const std::shared_ptr& src) const override; 22 | 23 | private: 24 | std::vector dims_; 25 | }; 26 | 27 | } // namespace op 28 | } // namespace data 29 | } // namespace mlx 30 | -------------------------------------------------------------------------------- /mlx/data/op/Tokenize.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/op/Tokenize.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace op { 8 | 9 | Tokenize::Tokenize( 10 | const std::string& ikey, 11 | std::shared_ptr> trie, 12 | TokenizeMode mode, 13 | bool ignore_unk, 14 | const std::vector& trie_key_scores, 15 | const std::string& okey) 16 | : KeyTransformOp(ikey, okey), 17 | tokenizer_(trie, ignore_unk, trie_key_scores), 18 | mode_(mode) {} 19 | 20 | std::shared_ptr Tokenize::apply_key( 21 | const std::shared_ptr& src) const { 22 | std::string str( 23 | reinterpret_cast(src->data()), src->size() * src->itemsize()); 24 | 25 | std::vector tokens; 26 | switch (mode_) { 27 | case TokenizeMode::shortest: 28 | tokens = tokenizer_.tokenize_shortest(str); 29 | break; 30 | case TokenizeMode::rand: 31 | tokens = tokenizer_.tokenize_rand(str); 32 | break; 33 | default: 34 | throw std::runtime_error("Tokenize: unsupported tokenize mode"); 35 | } 36 | 37 | return std::make_shared(tokens); 38 | } 39 | 40 | BPETokenize::BPETokenize( 41 | const std::string& ikey, 42 | std::shared_ptr> symbols, 43 | std::shared_ptr merges, 44 | const std::string& okey) 45 | : KeyTransformOp(ikey, okey), tokenizer_(symbols, merges) {} 46 | 47 | std::shared_ptr BPETokenize::apply_key( 48 | const std::shared_ptr& src) const { 49 | auto tokens = tokenizer_.tokenize(std::string_view( 50 | reinterpret_cast(src->data()), src->size() * src->itemsize())); 51 | return std::make_shared(tokens); 52 | } 53 | 54 | } // namespace op 55 | } // namespace data 56 | } // namespace mlx 57 | -------------------------------------------------------------------------------- /mlx/data/op/Tokenize.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/core/BPETokenizer.h" 6 | #include "mlx/data/core/Tokenizer.h" 7 | #include "mlx/data/core/Trie.h" 8 | #include "mlx/data/op/KeyTransform.h" 9 | 10 | namespace mlx { 11 | namespace data { 12 | namespace op { 13 | 14 | enum class TokenizeMode { shortest, rand }; 15 | 16 | class Tokenize : public KeyTransformOp { 17 | public: 18 | Tokenize( 19 | const std::string& ikey, 20 | std::shared_ptr> trie, 21 | TokenizeMode mode, 22 | bool ignore_unk = false, 23 | const std::vector& trie_key_scores = {}, 24 | const std::string& okey = ""); 25 | 26 | virtual std::shared_ptr apply_key( 27 | const std::shared_ptr& src) const override; 28 | 29 | private: 30 | core::Tokenizer tokenizer_; 31 | TokenizeMode mode_; 32 | }; 33 | 34 | class BPETokenize : public KeyTransformOp { 35 | public: 36 | BPETokenize( 37 | const std::string& ikey, 38 | std::shared_ptr> symbols, 39 | std::shared_ptr merges, 40 | const std::string& okey = ""); 41 | 42 | virtual std::shared_ptr apply_key( 43 | const std::shared_ptr& src) const override; 44 | 45 | private: 46 | core::BPETokenizer tokenizer_; 47 | }; 48 | 49 | } // namespace op 50 | } // namespace data 51 | } // namespace mlx 52 | -------------------------------------------------------------------------------- /mlx/data/stream/Batch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/stream/Batch.h" 4 | #include 5 | #include "mlx/data/core/ThreadPool.h" 6 | #include "mlx/data/core/Utils.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace stream { 11 | Batch::Batch( 12 | const std::shared_ptr& stream, 13 | int64_t batch_size, 14 | const std::unordered_map& pad_values, 15 | const std::unordered_map& batch_dims) 16 | : stream_(stream), 17 | batchSize_(batch_size), 18 | padValues_(pad_values), 19 | batchDims_(batch_dims) { 20 | if (batch_size <= 0) { 21 | throw std::runtime_error("Batch: batch size must be positive"); 22 | } 23 | } 24 | 25 | Sample Batch::next() const { 26 | std::vector samples; 27 | for (int i = 0; i < batchSize_; i++) { 28 | auto sample = stream_->next(); 29 | if (sample.empty()) { 30 | break; 31 | } 32 | samples.push_back(std::move(sample)); 33 | } 34 | if (samples.empty()) { 35 | return Sample(); 36 | } else { 37 | return core::merge_batch(samples, padValues_, batchDims_); 38 | } 39 | } 40 | 41 | void Batch::reset() { 42 | stream_->reset(); 43 | } 44 | 45 | } // namespace stream 46 | } // namespace data 47 | } // namespace mlx 48 | -------------------------------------------------------------------------------- /mlx/data/stream/Batch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/stream/Stream.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | class Batch : public Stream { 12 | public: 13 | Batch( 14 | const std::shared_ptr& stream, 15 | int64_t batch_size, 16 | const std::unordered_map& pad_values = {}, 17 | const std::unordered_map& batch_dims = {}); 18 | virtual Sample next() const override; 19 | virtual void reset() override; 20 | 21 | private: 22 | std::shared_ptr stream_; 23 | int64_t batchSize_; 24 | std::unordered_map padValues_; 25 | std::unordered_map batchDims_; 26 | }; 27 | 28 | } // namespace stream 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/stream/Buffered.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/stream/Buffered.h" 4 | #include "mlx/data/buffer/FromVector.h" 5 | #include "mlx/data/stream/FromBuffer.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | using FromVector = mlx::data::buffer::FromVector; 12 | 13 | Buffered::Buffered( 14 | const std::shared_ptr& stream, 15 | int64_t buffer_size, 16 | std::function( 17 | const std::shared_ptr&)> on_refill, 18 | int num_thread) 19 | : stream_(stream), 20 | bufferSize_(buffer_size), 21 | onRefill_(on_refill), 22 | pool_(std::make_shared(num_thread + 1)), 23 | currentIndex_(0), 24 | buffer_(nullptr) {} 25 | 26 | std::future> 27 | Buffered::background_buffer_fetch_() const { 28 | return pool_->enqueue([this]() -> std::shared_ptr { 29 | std::vector> future_buffer; 30 | for (int i = 0; i < bufferSize_; i++) { 31 | future_buffer.push_back( 32 | pool_->enqueue([this] { return stream_->next(); })); 33 | } 34 | std::vector buffer; 35 | for (auto& fsample : future_buffer) { 36 | Sample sample = fsample.get(); 37 | if (!sample.empty()) { 38 | buffer.push_back(sample); 39 | } 40 | } 41 | 
42 | return onRefill_(std::make_shared(buffer)); 43 | }); 44 | } 45 | 46 | Sample Buffered::next() const { 47 | std::unique_lock lock(mutex_); 48 | 49 | // First run 50 | if (buffer_ == nullptr) { 51 | buffer_ = background_buffer_fetch_().get(); 52 | nextBuffer_ = background_buffer_fetch_(); 53 | } 54 | 55 | // We are done 56 | if (buffer_->size() == 0) { 57 | return Sample(); 58 | } 59 | 60 | // Normal running 61 | if (currentIndex_ >= buffer_->size()) { 62 | currentIndex_ = 0; 63 | buffer_ = nextBuffer_.get(); 64 | nextBuffer_ = background_buffer_fetch_(); 65 | 66 | if (buffer_->size() == 0) { 67 | return Sample(); 68 | } 69 | } 70 | 71 | return buffer_->get(currentIndex_++); 72 | } 73 | 74 | void Buffered::reset() { 75 | std::unique_lock lock(mutex_); 76 | 77 | buffer_ = nullptr; 78 | if (nextBuffer_.valid()) { 79 | nextBuffer_.get(); 80 | } 81 | stream_->reset(); 82 | } 83 | 84 | std::shared_ptr Buffered::on_refill_default( 85 | const std::shared_ptr& buffer) { 86 | return buffer; 87 | } 88 | 89 | } // namespace stream 90 | } // namespace data 91 | } // namespace mlx 92 | -------------------------------------------------------------------------------- /mlx/data/stream/Buffered.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/buffer/Buffer.h" 9 | #include "mlx/data/stream/Stream.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace stream { 14 | 15 | class Buffered : public Stream { 16 | public: 17 | Buffered( 18 | const std::shared_ptr& stream, 19 | int64_t buffer_size, 20 | std::function( 21 | const std::shared_ptr&)> on_refill = 22 | on_refill_default, 23 | int num_thread = 1); 24 | 25 | virtual Sample next() const override; 26 | virtual void reset() override; 27 | 28 | static std::shared_ptr on_refill_default( 29 | const std::shared_ptr& buffer); 30 | 31 | private: 32 | std::future> background_buffer_fetch_() const; 33 | 34 | std::shared_ptr stream_; // underlying stream 35 | int64_t bufferSize_; // how many buffer items 36 | std::function( 37 | const std::shared_ptr)> 38 | onRefill_; // operation to be performed on top of buffered items 39 | 40 | std::shared_ptr pool_; 41 | mutable int currentIndex_; 42 | mutable std::shared_ptr buffer_; 43 | mutable std::future> nextBuffer_; 44 | mutable std::shared_mutex mutex_; 45 | }; 46 | 47 | } // namespace stream 48 | } // namespace data 49 | } // namespace mlx 50 | -------------------------------------------------------------------------------- /mlx/data/stream/CSVReader.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/stream/CSVReader.h" 4 | #include "mlx/data/core/imemstream.h" 5 | 6 | namespace mlx { 7 | namespace data { 8 | namespace stream { 9 | CSVReader::CSVReader( 10 | const std::string& filename, 11 | char sep, 12 | char quote, 13 | const std::filesystem::path& local_prefix, 14 | std::shared_ptr fetcher) { 15 | if (fetcher) { 16 | fileHandle_ = fetcher->fetch(filename); 17 | } 18 | auto file_path = local_prefix / filename; 19 | csv_ = std::make_unique(file_path.string(), sep, quote); 20 | keys_ = csv_->next(); 21 | } 22 | CSVReader::CSVReader( 23 | const std::shared_ptr& f, 24 | char sep, 25 | char quote, 26 | std::shared_ptr file_handle) 27 | : fileHandle_(file_handle) { 28 | csv_ = std::make_unique(f, sep, quote); 29 | keys_ = csv_->next(); 30 | } 31 | void CSVReader::reset() { 32 | std::unique_lock lock(mutex_); 33 | csv_->reset(); 34 | csv_->next(); // keys 35 | } 36 | 37 | Sample CSVReader::next() const { 38 | std::vector sample_str; 39 | { 40 | std::unique_lock lock(mutex_); 41 | sample_str = csv_->next(); 42 | } 43 | if (sample_str.empty()) { 44 | return Sample(); 45 | } 46 | if (sample_str.size() != keys_.size()) { 47 | throw std::runtime_error("CSVReader: inconsistent number of fields"); 48 | } 49 | Sample sample; 50 | for (size_t i = 0; i < sample_str.size(); i++) { 51 | sample[keys_.at(i)] = std::make_shared(sample_str.at(i)); 52 | } 53 | return sample; 54 | } 55 | 56 | CSVReaderFromKey::CSVReaderFromKey( 57 | std::shared_ptr stream, 58 | const std::string& key, 59 | char sep, 60 | char quote, 61 | bool fromMemory, 62 | const std::filesystem::path& local_prefix, 63 | const std::shared_ptr& fetcher) 64 | : Compose(stream, [=](const Sample& sample) { 65 | if (fromMemory) { 66 | auto array = 67 | sample::check_key(sample, key, mlx::data::ArrayType::UInt8); 68 | auto ms = std::make_shared(array); 69 | return std::make_shared(ms, sep, quote); 70 | } else { 71 | auto array = 72 | sample::check_key(sample, key, mlx::data::ArrayType::Int8); 
73 | std::string filename( 74 | reinterpret_cast(array->data()), array->size()); 75 | return std::make_shared( 76 | filename, sep, quote, local_prefix, fetcher); 77 | } 78 | }) {} 79 | 80 | } // namespace stream 81 | } // namespace data 82 | } // namespace mlx 83 | -------------------------------------------------------------------------------- /mlx/data/stream/CSVReader.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/core/CSVReader.h" 9 | #include "mlx/data/core/FileFetcher.h" 10 | #include "mlx/data/stream/Compose.h" 11 | #include "mlx/data/stream/Stream.h" 12 | 13 | namespace mlx { 14 | namespace data { 15 | namespace stream { 16 | 17 | class CSVReader : public Stream { 18 | public: 19 | CSVReader( 20 | const std::string& filename, 21 | char sep = ',', 22 | char quote = '"', 23 | const std::filesystem::path& local_prefix = "", 24 | std::shared_ptr fetcher = nullptr); 25 | CSVReader( 26 | const std::shared_ptr& f, 27 | char sep = ',', 28 | char quote = '"', 29 | std::shared_ptr file_handle = nullptr); 30 | virtual Sample next() const override; 31 | void reset() override; 32 | 33 | private: 34 | std::unique_ptr csv_; 35 | std::vector keys_; 36 | std::shared_ptr fileHandle_; 37 | mutable std::mutex mutex_; 38 | }; 39 | 40 | class CSVReaderFromKey : public Compose { 41 | public: 42 | CSVReaderFromKey( 43 | std::shared_ptr stream, 44 | const std::string& key, 45 | char sep = ',', 46 | char quote = '"', 47 | bool from_memory = false, 48 | const std::filesystem::path& local_prefix = "", 49 | const std::shared_ptr& fetcher = nullptr); 50 | }; 51 | 52 | } // namespace stream 53 | } // namespace data 54 | } // namespace mlx 55 | -------------------------------------------------------------------------------- /mlx/data/stream/Compose.cpp: -------------------------------------------------------------------------------- 1 | // 
Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/stream/Compose.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | Compose::Compose( 12 | std::shared_ptr& stream, 13 | std::function(const Sample& sample)> op) 14 | : stream_(stream), op_(op) {}; 15 | 16 | bool Compose::next_stream_() const { 17 | auto sample = stream_->next(); 18 | if (sample.empty()) { 19 | return false; 20 | } 21 | composedStream_ = op_(sample); 22 | if (!composedStream_) { 23 | throw std::runtime_error( 24 | "Compose: composer unexpectedly returned a nullptr stream"); 25 | } 26 | return true; 27 | } 28 | 29 | Sample Compose::next() const { 30 | // note: composedStream_ is read by many threads 31 | // and written by one thread once in a while 32 | std::shared_lock slock(mutex_); 33 | 34 | // Composed stream is not created yet 35 | if (composedStream_ == nullptr) { 36 | slock.unlock(); 37 | { 38 | std::unique_lock ulock(mutex_); 39 | if (!composedStream_) { 40 | if (!next_stream_()) { 41 | return Sample(); // EOF 42 | } 43 | } 44 | } 45 | slock.lock(); 46 | } 47 | 48 | Sample sample; 49 | while (sample.empty()) { 50 | sample = composedStream_->next(); 51 | if (sample.empty()) { 52 | slock.unlock(); 53 | { 54 | std::unique_lock ulock(mutex_); 55 | // maybe we got the lock after the stream was updated 56 | sample = composedStream_->next(); 57 | if (sample.empty()) { 58 | if (!next_stream_()) { 59 | return sample; // EOF 60 | } 61 | sample = composedStream_->next(); 62 | } 63 | } 64 | slock.lock(); 65 | } 66 | } 67 | 68 | return sample; 69 | } 70 | 71 | void Compose::reset() { 72 | std::unique_lock lock(mutex_); 73 | stream_->reset(); 74 | composedStream_ = nullptr; 75 | } 76 | 77 | } // namespace stream 78 | } // namespace data 79 | } // namespace mlx 80 | -------------------------------------------------------------------------------- /mlx/data/stream/Compose.h: -------------------------------------------------------------------------------- 1 | 
// Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/data/stream/Stream.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace stream { 14 | 15 | class Compose : public Stream { 16 | public: 17 | Compose( 18 | std::shared_ptr& stream, 19 | std::function(const Sample& sample)> op); 20 | 21 | virtual Sample next() const override; 22 | virtual void reset() override; 23 | 24 | protected: 25 | bool next_stream_() const; 26 | 27 | mutable std::shared_ptr stream_; 28 | mutable std::shared_ptr composedStream_; 29 | mutable std::shared_mutex mutex_; 30 | std::function(const Sample& sample)> op_; 31 | }; 32 | 33 | } // namespace stream 34 | } // namespace data 35 | } // namespace mlx 36 | -------------------------------------------------------------------------------- /mlx/data/stream/DynamicBatch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/stream/DynamicBatch.h" 4 | #include "mlx/data/buffer/DynamicBatch.h" 5 | #include "mlx/data/buffer/Shuffle.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | DynamicBatch::DynamicBatch( 12 | std::shared_ptr stream, 13 | int64_t buffer_size, 14 | const std::string& key, 15 | int64_t max_data_size, 16 | const std::unordered_map& pad_values, 17 | const std::unordered_map& batch_dims, 18 | bool shuffle, 19 | int num_thread) 20 | : Buffered( 21 | stream, 22 | buffer_size, 23 | onRefill_(key, max_data_size, pad_values, batch_dims, shuffle), 24 | num_thread) {}; 25 | 26 | std::function< 27 | std::shared_ptr(const std::shared_ptr)> 28 | DynamicBatch::onRefill_( 29 | const std::string& key, 30 | int64_t max_data_size, 31 | const std::unordered_map& pad_values, 32 | const std::unordered_map& batch_dims, 33 | bool shuffle) { 34 | auto on_refill = [key, max_data_size, pad_values, batch_dims, shuffle]( 35 | std::shared_ptr buffer) { 36 | buffer = 
std::make_shared( 37 | buffer, key, max_data_size, pad_values, batch_dims); 38 | if (shuffle) { 39 | buffer = std::make_shared(buffer); 40 | } 41 | return buffer; 42 | }; 43 | 44 | return on_refill; 45 | } 46 | 47 | } // namespace stream 48 | } // namespace data 49 | } // namespace mlx 50 | -------------------------------------------------------------------------------- /mlx/data/stream/DynamicBatch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/data/stream/Buffered.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | class DynamicBatch : public Buffered { 12 | public: 13 | DynamicBatch( 14 | std::shared_ptr stream, 15 | int64_t buffer_size, 16 | const std::string& key, 17 | int64_t max_data_size = 0, // batch everything if <= 0 18 | const std::unordered_map& pad_values = {}, 19 | const std::unordered_map& batch_dims = {}, 20 | bool shuffle = false, 21 | int num_thread = 1); 22 | 23 | private: 24 | static std::function< 25 | std::shared_ptr(const std::shared_ptr)> 26 | onRefill_( 27 | const std::string& key, 28 | int64_t max_data_size, 29 | const std::unordered_map& pad_values, 30 | const std::unordered_map& batch_dims, 31 | bool shuffle); 32 | }; 33 | 34 | } // namespace stream 35 | } // namespace data 36 | } // namespace mlx 37 | -------------------------------------------------------------------------------- /mlx/data/stream/FromBuffer.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/stream/FromBuffer.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace stream { 8 | 9 | FromBuffer::FromBuffer(const std::shared_ptr& buffer) 10 | : buffer_(buffer), currentIdx_(0) {} 11 | 12 | Sample FromBuffer::next() const { 13 | int64_t idx = -1; 14 | { 15 | std::lock_guard lock(mutex_); 16 | if (currentIdx_ < buffer_->size()) { 17 | idx = currentIdx_++; 18 | } 19 | } 20 | if (idx < 0) { 21 | return Sample(); 22 | } else { 23 | return buffer_->get(idx); 24 | } 25 | } 26 | 27 | void FromBuffer::reset() { 28 | std::lock_guard lock(mutex_); 29 | currentIdx_ = 0; 30 | } 31 | 32 | } // namespace stream 33 | } // namespace data 34 | } // namespace mlx 35 | -------------------------------------------------------------------------------- /mlx/data/stream/FromBuffer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/data/buffer/Buffer.h" 8 | #include "mlx/data/stream/Stream.h" 9 | 10 | namespace mlx { 11 | namespace data { 12 | namespace stream { 13 | 14 | class FromBuffer : public Stream { 15 | public: 16 | FromBuffer(const std::shared_ptr& buffer); 17 | 18 | virtual Sample next() const override; 19 | virtual void reset() override; 20 | 21 | private: 22 | std::shared_ptr buffer_; 23 | mutable int64_t currentIdx_; 24 | mutable std::mutex mutex_; 25 | }; 26 | 27 | } // namespace stream 28 | } // namespace data 29 | } // namespace mlx 30 | -------------------------------------------------------------------------------- /mlx/data/stream/LineReader.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "bxzstr/bxzstr.hpp" 10 | 11 | #include "mlx/data/core/FileFetcher.h" 12 | #include "mlx/data/stream/Compose.h" 13 | #include "mlx/data/stream/Stream.h" 14 | 15 | namespace mlx { 16 | namespace data { 17 | namespace stream { 18 | 19 | class LineReader : public Stream { 20 | public: 21 | LineReader( 22 | const std::string& filename, 23 | const std::string& key, 24 | bool unzip = false, 25 | const std::filesystem::path& local_prefix = "", 26 | std::shared_ptr fetcher = nullptr); 27 | LineReader( 28 | const std::shared_ptr& f, 29 | const std::string& key, 30 | bool unzip = false, 31 | std::shared_ptr file_handle = nullptr); 32 | virtual Sample next() const override; 33 | void reset() override; 34 | 35 | private: 36 | void init_(const std::shared_ptr& f, bool unzip); 37 | Sample process_() const; 38 | std::string filename_; 39 | std::shared_ptr f_; 40 | std::shared_ptr uf_; 41 | std::string key_; 42 | std::shared_ptr fileHandle_; 43 | mutable std::mutex mutex_; 44 | }; 45 | 46 | class LineReaderFromKey : public Compose { 47 | public: 48 | LineReaderFromKey( 49 | std::shared_ptr stream, 50 | const std::string& key, 51 | const std::string& dst_key, 52 | bool from_memory = false, 53 | bool unzip = false, 54 | const std::filesystem::path& local_prefix = "", 55 | std::shared_ptr fetcher = nullptr); 56 | }; 57 | 58 | } // namespace stream 59 | } // namespace data 60 | } // namespace mlx 61 | -------------------------------------------------------------------------------- /mlx/data/stream/OrderedPrefetch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/stream/OrderedPrefetch.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace stream { 8 | 9 | OrderedPrefetch::OrderedPrefetch( 10 | const std::shared_ptr& buffer, 11 | int prefetch_size, 12 | int num_thread) 13 | : buffer_(buffer), 14 | pool_(std::make_shared(num_thread)), 15 | prefetchSize_(prefetch_size), 16 | currentIdx_(0) { 17 | if (prefetchSize_ <= 0) { 18 | throw std::runtime_error( 19 | "Prefetch: prefetch size must be strictly positive"); 20 | } 21 | } 22 | 23 | OrderedPrefetch::~OrderedPrefetch() { 24 | std::lock_guard lock(mutex_); 25 | 26 | prefetchCache_.clear(); 27 | } 28 | 29 | Sample OrderedPrefetch::next() const { 30 | std::unique_lock lock(mutex_); 31 | 32 | // First time we are called so enqueue all the fetching 33 | if (prefetchCache_.size() < prefetchSize_) { 34 | for (int i = 0; i < std::min(prefetchSize_, buffer_->size()); i++) { 35 | prefetchCache_.emplace_back( 36 | pool_->enqueue([b = buffer_, i] { return b->get(i); })); 37 | } 38 | } 39 | 40 | int64_t idx = -1; 41 | if (currentIdx_ < buffer_->size()) { 42 | idx = currentIdx_++; 43 | } 44 | 45 | if (idx < 0) { 46 | return Sample(); 47 | } else { 48 | int f_idx = idx % prefetchSize_; 49 | std::future fsample(std::move(prefetchCache_[f_idx])); 50 | int next_idx = idx + prefetchSize_; 51 | if (next_idx < buffer_->size()) { 52 | prefetchCache_[f_idx] = 53 | pool_->enqueue([b = buffer_, next_idx] { return b->get(next_idx); }); 54 | } 55 | lock.unlock(); 56 | return fsample.get(); 57 | } 58 | } 59 | 60 | void OrderedPrefetch::reset() { 61 | std::lock_guard lock(mutex_); 62 | currentIdx_ = 0; 63 | 64 | prefetchCache_.clear(); 65 | } 66 | 67 | } // namespace stream 68 | } // namespace data 69 | } // namespace mlx 70 | -------------------------------------------------------------------------------- /mlx/data/stream/OrderedPrefetch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/buffer/Buffer.h" 9 | #include "mlx/data/core/ThreadPool.h" 10 | #include "mlx/data/stream/Stream.h" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace stream { 15 | 16 | class OrderedPrefetch : public Stream { 17 | public: 18 | OrderedPrefetch( 19 | const std::shared_ptr& stream, 20 | int prefetch_size, 21 | int num_thread); 22 | ~OrderedPrefetch(); 23 | 24 | virtual Sample next() const override; 25 | virtual void reset() override; 26 | 27 | private: 28 | std::shared_ptr buffer_; 29 | std::shared_ptr pool_; 30 | int64_t prefetchSize_; 31 | mutable int64_t currentIdx_; 32 | mutable std::vector> prefetchCache_; 33 | mutable std::mutex mutex_; 34 | }; 35 | 36 | } // namespace stream 37 | } // namespace data 38 | } // namespace mlx 39 | -------------------------------------------------------------------------------- /mlx/data/stream/Partition.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/data/stream/Partition.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace stream { 8 | 9 | Partition::Partition( 10 | const std::shared_ptr& stream, 11 | int64_t num_partitions, 12 | int64_t partition) 13 | : stream_(stream), numPartitions_(num_partitions), partition_(partition) { 14 | if (num_partitions < 0) { 15 | throw std::runtime_error( 16 | "Partition: number of partitions must be positive"); 17 | } 18 | if (partition < 0 || partition >= num_partitions) { 19 | throw std::runtime_error("Partition: selected partition is out of range"); 20 | } 21 | } 22 | 23 | Sample Partition::next() const { 24 | std::unique_lock lock(stream_mutex_); 25 | 26 | Sample res; 27 | for (int i = 0; i < numPartitions_; i++) { 28 | auto sample = stream_->next(); 29 | if (i == partition_) { 30 | res = std::move(sample); 31 | } 32 | } 33 | 34 | return res; 35 | } 36 | 37 | void Partition::reset() { 38 | std::unique_lock lock(stream_mutex_); 39 | stream_->reset(); 40 | } 41 | 42 | } // namespace stream 43 | } // namespace data 44 | } // namespace mlx 45 | -------------------------------------------------------------------------------- /mlx/data/stream/Partition.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/data/stream/Stream.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | class Partition : public Stream { 12 | public: 13 | Partition( 14 | const std::shared_ptr& stream, 15 | int64_t num_partitions, 16 | int64_t partition); 17 | 18 | virtual Sample next() const override; 19 | virtual void reset() override; 20 | 21 | private: 22 | std::shared_ptr stream_; 23 | int64_t numPartitions_; 24 | int64_t partition_; 25 | mutable std::mutex stream_mutex_; 26 | }; 27 | 28 | } // namespace stream 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/stream/Prefetch.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/data/stream/Prefetch.h" 4 | 5 | namespace mlx { 6 | namespace data { 7 | namespace stream { 8 | 9 | Prefetch::Prefetch( 10 | const std::shared_ptr& stream, 11 | int prefetch_size, 12 | int num_thread) 13 | : stream_(stream), 14 | pool_(std::make_shared(num_thread)), 15 | prefetchSize_(prefetch_size) { 16 | if (prefetchSize_ < 0) { 17 | throw std::runtime_error("Prefetch: prefetch size must be positive"); 18 | } 19 | } 20 | 21 | Prefetch::~Prefetch() { 22 | std::unique_lock lock(mutex_); 23 | while (prefetchCache_.size()) { 24 | prefetchCache_.front().get(); 25 | prefetchCache_.pop(); 26 | } 27 | } 28 | 29 | Sample Prefetch::next() const { 30 | std::unique_lock lock(mutex_); 31 | 32 | // First time we are called so enqueue all the fetching 33 | if (prefetchCache_.size() < prefetchSize_) { 34 | for (int i = 0; i < prefetchSize_; i++) { 35 | prefetchCache_.emplace( 36 | pool_->enqueue([s = stream_] { return s->next(); })); 37 | } 38 | } 39 | 40 | // We are looping prefetchSize_ times. If all we get is empty then the 41 | // underlying stream is indeed exhausted. 
42 | Sample res; 43 | for (int i = 0; i < prefetchSize_; i++) { 44 | std::future fsample; 45 | fsample = std::move(prefetchCache_.front()); 46 | prefetchCache_.pop(); 47 | prefetchCache_.emplace(pool_->enqueue([s = stream_] { return s->next(); })); 48 | res = fsample.get(); 49 | 50 | if (!res.empty()) { 51 | break; 52 | } 53 | } 54 | 55 | return res; 56 | } 57 | 58 | void Prefetch::reset() { 59 | std::unique_lock lock(mutex_); 60 | 61 | while (prefetchCache_.size()) { 62 | prefetchCache_.front().get(); 63 | prefetchCache_.pop(); 64 | } 65 | stream_->reset(); 66 | } 67 | 68 | } // namespace stream 69 | } // namespace data 70 | } // namespace mlx 71 | -------------------------------------------------------------------------------- /mlx/data/stream/Prefetch.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/core/ThreadPool.h" 9 | #include "mlx/data/stream/Stream.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace stream { 14 | 15 | class Prefetch : public Stream { 16 | public: 17 | Prefetch( 18 | const std::shared_ptr& stream, 19 | int prefetch_size, 20 | int num_thread); 21 | ~Prefetch(); 22 | 23 | virtual Sample next() const override; 24 | virtual void reset() override; 25 | 26 | private: 27 | std::shared_ptr stream_; 28 | std::shared_ptr pool_; 29 | int prefetchSize_; 30 | mutable std::queue> prefetchCache_; 31 | mutable std::mutex mutex_; 32 | }; 33 | 34 | } // namespace stream 35 | } // namespace data 36 | } // namespace mlx 37 | -------------------------------------------------------------------------------- /mlx/data/stream/Repeat.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/data/stream/Repeat.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | Repeat::Repeat(const std::shared_ptr& stream, int64_t num_time) 12 | : stream_(stream), numTime_(num_time), numDone_(0) {}; 13 | 14 | Sample Repeat::next() const { 15 | Sample sample; 16 | { 17 | std::shared_lock slock(stream_reset_mutex_); 18 | sample = stream_->next(); 19 | } 20 | 21 | // Empty sample we may need to reset the underlying stream 22 | if (sample.empty()) { 23 | { 24 | std::unique_lock lock(stream_reset_mutex_); 25 | 26 | // Get another sample in case someone else reset the stream in the 27 | // meantime. 28 | sample = stream_->next(); 29 | if (!sample.empty()) { 30 | return sample; 31 | } 32 | 33 | // We are not allowed to reset anymore 34 | if (numTime_ > 0 && numDone_ >= numTime_) { 35 | return sample; 36 | } 37 | 38 | numDone_++; 39 | stream_->reset(); 40 | sample = stream_->next(); 41 | } 42 | 43 | return sample; 44 | } 45 | 46 | return sample; 47 | } 48 | 49 | void Repeat::reset() { 50 | std::unique_lock ulock(stream_reset_mutex_); 51 | stream_->reset(); 52 | numDone_ = 0; 53 | } 54 | 55 | } // namespace stream 56 | } // namespace data 57 | } // namespace mlx 58 | -------------------------------------------------------------------------------- /mlx/data/stream/Repeat.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "mlx/data/stream/Stream.h" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace stream { 15 | 16 | class Repeat : public Stream { 17 | public: 18 | Repeat(const std::shared_ptr& stream, int64_t num_time); 19 | 20 | virtual Sample next() const override; 21 | virtual void reset() override; 22 | 23 | protected: 24 | std::shared_ptr stream_; 25 | int64_t numTime_; 26 | mutable std::shared_mutex stream_reset_mutex_; 27 | mutable int64_t numDone_; 28 | }; 29 | 30 | } // namespace stream 31 | } // namespace data 32 | } // namespace mlx 33 | -------------------------------------------------------------------------------- /mlx/data/stream/Shuffle.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/core/State.h" 6 | #include "mlx/data/stream/Shuffle.h" 7 | 8 | namespace mlx { 9 | namespace data { 10 | namespace stream { 11 | 12 | Shuffle::Shuffle(const std::shared_ptr& stream, int buffer_size) 13 | : stream_(stream), buffer_size_(buffer_size) {} 14 | 15 | Sample Shuffle::next() const { 16 | // The while is really only for case 1 below but it reads a bit better than 17 | // putting the while loop in lines 30-35 I believe. 18 | while (true) { 19 | // First get a sample from the underlying stream 20 | auto sample = stream_->next(); 21 | 22 | // Now there are a couple of cases we have to consider 23 | // 24 | // 1. The sample is not empty and the buffer is not full -> keep fetching 25 | // 2. The sample is not empty and the buffer is full -> standard case 26 | // 3. The sample is empty and the buffer is not empty -> pop a random sample 27 | // 4. 
The sample is empty and the buffer is empty -> we are done 28 | 29 | if (!sample.empty()) { 30 | std::uniform_int_distribution pos_dis(0, buffer_size_ - 1); 31 | int pos = pos_dis(core::get_state()->randomGenerator); 32 | 33 | { 34 | std::unique_lock lock(mutex_); 35 | 36 | if (buffer_.size() < buffer_size_) { 37 | buffer_.emplace_back(sample); 38 | continue; 39 | } 40 | 41 | std::swap(sample, buffer_[pos]); 42 | 43 | return sample; 44 | } 45 | } else { 46 | std::unique_lock lock(mutex_); 47 | 48 | if (buffer_.size() > 0) { 49 | std::uniform_int_distribution pos_dis(0, buffer_.size() - 1); 50 | int pos = pos_dis(core::get_state()->randomGenerator); 51 | 52 | sample = std::move(buffer_[pos]); 53 | buffer_.erase(buffer_.begin() + pos); 54 | } 55 | 56 | return sample; 57 | } 58 | } 59 | } 60 | 61 | void Shuffle::reset() { 62 | std::unique_lock lock(mutex_); 63 | 64 | stream_->reset(); 65 | buffer_.clear(); 66 | } 67 | 68 | } // namespace stream 69 | } // namespace data 70 | } // namespace mlx 71 | -------------------------------------------------------------------------------- /mlx/data/stream/Shuffle.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/data/stream/Stream.h" 8 | 9 | namespace mlx { 10 | namespace data { 11 | namespace stream { 12 | 13 | class Shuffle : public Stream { 14 | public: 15 | Shuffle(const std::shared_ptr& stream, int buffer_size); 16 | 17 | virtual Sample next() const override; 18 | virtual void reset() override; 19 | 20 | private: 21 | std::shared_ptr stream_; 22 | int buffer_size_; 23 | mutable std::vector buffer_; 24 | mutable std::mutex mutex_; 25 | }; 26 | 27 | } // namespace stream 28 | } // namespace data 29 | } // namespace mlx 30 | -------------------------------------------------------------------------------- /mlx/data/stream/SlidingWindow.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/stream/SlidingWindow.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | SlidingWindow::SlidingWindow( 11 | const std::shared_ptr& stream, 12 | const std::string& key, 13 | int64_t size, 14 | int64_t stride, 15 | int dim, 16 | const std::string& index_key) 17 | : stream_(stream), 18 | key_(key), 19 | size_(size), 20 | stride_(stride), 21 | dim_(dim), 22 | index_key_(index_key) { 23 | if (size <= 0) { 24 | throw std::runtime_error("SlidingWindow: size must be strictly positive"); 25 | } 26 | if (stride <= 0) { 27 | throw std::runtime_error("SlidingWindow: stride must be strictly positive"); 28 | } 29 | } 30 | 31 | Sample SlidingWindow::next() const { 32 | std::unique_lock lock(mutex_); 33 | 34 | // Check if we already created some samples in which case simply return 35 | // the first one. 36 | if (!buffer_.empty()) { 37 | auto sample = std::move(buffer_.front()); 38 | buffer_.pop(); 39 | return sample; 40 | } 41 | 42 | // The buffer is empty so we need to get the next sample from the stream 43 | // and make a sliding window that is saved in the buffer_. 
44 | std::queue buffer; 45 | while (buffer.empty()) { 46 | // Fetch the next full sample 47 | auto sample = stream_->next(); 48 | if (sample.empty()) { 49 | return sample; 50 | } 51 | 52 | auto array = sample::check_key(sample, key_, ArrayType::Any); 53 | int dim = array->checkdim(dim_); 54 | int64_t length = array->shape(dim); 55 | auto newshape = array->shape(); 56 | std::vector newoffset(array->ndim(), 0); 57 | int64_t offset = 0; 58 | int64_t slice_index = 0; 59 | while (offset < length) { 60 | auto newsample = sample; 61 | int64_t newlength = 62 | ((size_ <= (length - offset)) ? size_ : (length - offset)); 63 | newshape[dim] = newlength; 64 | newoffset[dim] = offset; 65 | newsample[key_] = array::sub(array, newoffset, newshape); 66 | if (!index_key_.empty()) { 67 | newsample[index_key_] = std::make_shared(slice_index); 68 | } 69 | buffer.emplace(newsample); 70 | offset += stride_; 71 | slice_index++; 72 | } 73 | } 74 | 75 | auto sample = std::move(buffer.front()); 76 | buffer.pop(); 77 | buffer_ = buffer; 78 | 79 | return sample; 80 | } 81 | 82 | void SlidingWindow::reset() { 83 | std::unique_lock lock(mutex_); 84 | buffer_ = std::queue(); 85 | stream_->reset(); 86 | } 87 | 88 | } // namespace stream 89 | } // namespace data 90 | } // namespace mlx 91 | -------------------------------------------------------------------------------- /mlx/data/stream/SlidingWindow.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/stream/Stream.h" 9 | 10 | namespace mlx { 11 | namespace data { 12 | namespace stream { 13 | 14 | class SlidingWindow : public Stream { 15 | public: 16 | SlidingWindow( 17 | const std::shared_ptr& stream, 18 | const std::string& key, 19 | int64_t size, 20 | int64_t stride, 21 | int dim = -1, 22 | const std::string& index_key = ""); 23 | 24 | virtual Sample next() const override; 25 | virtual void reset() override; 26 | 27 | private: 28 | mutable std::mutex mutex_; 29 | mutable std::queue buffer_; 30 | std::shared_ptr stream_; 31 | std::string key_; 32 | int64_t size_; 33 | int64_t stride_; 34 | int dim_; 35 | std::string index_key_; 36 | }; 37 | 38 | } // namespace stream 39 | } // namespace data 40 | } // namespace mlx 41 | -------------------------------------------------------------------------------- /mlx/data/stream/Stream.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/stream/Stream.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | Sample Stream::next() const { 12 | throw std::runtime_error("Stream::next() NYI"); 13 | } 14 | 15 | void Stream::reset() { 16 | throw std::runtime_error("Stream::reset() NYI"); 17 | } 18 | 19 | Stream::~Stream() {} 20 | 21 | } // namespace stream 22 | } // namespace data 23 | } // namespace mlx 24 | -------------------------------------------------------------------------------- /mlx/data/stream/Stream.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/data/Sample.h" 9 | #include "mlx/data/core/State.h" 10 | 11 | namespace mlx { 12 | namespace data { 13 | namespace stream { 14 | 15 | class Stream { 16 | public: 17 | Stream() {}; 18 | 19 | // fetch next sample 20 | virtual Sample next() const; 21 | 22 | // reset the stream 23 | virtual void reset(); 24 | 25 | virtual ~Stream(); 26 | }; 27 | 28 | } // namespace stream 29 | } // namespace data 30 | } // namespace mlx 31 | -------------------------------------------------------------------------------- /mlx/data/stream/Transform.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/data/stream/Transform.h" 6 | 7 | namespace mlx { 8 | namespace data { 9 | namespace stream { 10 | 11 | Transform::Transform( 12 | const std::shared_ptr& stream, 13 | const std::shared_ptr& op) 14 | : stream_(stream), ops_({op}) {}; 15 | 16 | Transform::Transform( 17 | const std::shared_ptr& stream, 18 | const std::vector>& ops) 19 | : stream_(stream), ops_(ops) {}; 20 | 21 | Sample Transform::next() const { 22 | // Process the stream untill it is either exhausted or a sample is 23 | // generated. While doing so mark the skipped elements. 
24 | Sample res; 25 | 26 | while (res.empty()) { 27 | auto sample = stream_->next(); 28 | 29 | // We exhausted the stream so return an empty sample 30 | if (sample.empty()) { 31 | break; 32 | } 33 | 34 | // Got a sample let's transform it 35 | res = sample; 36 | for (auto& op : ops_) { 37 | res = op->apply(res); 38 | 39 | // Hmm we should skip it 40 | if (res.empty()) { 41 | break; 42 | } 43 | } 44 | } 45 | 46 | return res; 47 | } 48 | 49 | void Transform::reset() { 50 | stream_->reset(); 51 | } 52 | 53 | } // namespace stream 54 | } // namespace data 55 | } // namespace mlx 56 | -------------------------------------------------------------------------------- /mlx/data/stream/Transform.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/data/op/Op.h" 10 | #include "mlx/data/stream/Stream.h" 11 | 12 | namespace mlx { 13 | namespace data { 14 | namespace stream { 15 | 16 | class Transform : public Stream { 17 | public: 18 | Transform( 19 | const std::shared_ptr& stream, 20 | const std::shared_ptr& op); 21 | Transform( 22 | const std::shared_ptr& stream, 23 | const std::vector>& ops); 24 | 25 | virtual Sample next() const override; 26 | virtual void reset() override; 27 | 28 | protected: 29 | std::shared_ptr stream_; 30 | std::vector> ops_; 31 | }; 32 | 33 | } // namespace stream 34 | } // namespace data 35 | } // namespace mlx 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "cmake>=3.23.3", "pybind11==2.13.6"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /python/mlx/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | # pybind11 may import numpy.core.multiarray (eg., for dtype()), 4 | # and may do it in a thread. To prevent any GIL lock issue, we preload 5 | # this subpackage. 6 | import numpy.core.multiarray 7 | 8 | # fmt: off 9 | # pybind11 will import numpy, and may do it in a thread. 10 | # in that event, openblas initialization may lead to invalid address read errors. 11 | # importing numpy in the main thread alleviate the issue. 12 | # alternatively, one can set OPENBLAS_NUM_THREADS=1. 13 | import numpy # isort: skip 14 | 15 | try: 16 | del numpy.core.multiarray 17 | except AttributeError: 18 | pass 19 | 20 | del numpy 21 | 22 | from . import tokenizer_helpers 23 | from ._c import * 24 | from ._c import __version__ 25 | -------------------------------------------------------------------------------- /python/mlx/data/core.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from ._c.core import * 4 | -------------------------------------------------------------------------------- /python/mlx/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
def load_images_from_folder(image_folder):
    """Load images from a folder.

    For a directory structure like the following

    .. code-block::

        image_folder/
        ...class_1/
        ......foo.jpg
        ......bar.png
        ...class_2/
        ......baz.jpg
        ......foo_again.png

    this function will return a :class:`~mlx.data.Buffer` that contains samples
    with the following keys

    - **folder**: the name of the category of this sample (e.g. class_1, class_2 etc)
    - **label**: an integer that corresponds to the sorted position of the folder
      names (e.g. class_1 gets 0 and class_2 gets 1)
    - **file**: the path to the image relative to the provided root folder
    - **image**: the loaded image array

    Samples are ordered by folder name and then by file name, so the order
    does not depend on how the filesystem enumerates directory entries.
    Entries of a class folder that are not regular files are ignored.

    Args:
        image_folder: (Path or str): The directory to load the images from.

    Raises:
        ValueError: If ``image_folder`` is not a directory or contains no
            sub-directories.
    """
    root = Path(image_folder)
    if not root.is_dir():
        raise ValueError(f"The provided path {root} is not a directory")

    directories = sorted(
        (f for f in root.iterdir() if f.is_dir()), key=lambda x: x.name
    )
    if not directories:
        raise ValueError(f"The provided path {root} contains no directories")

    classes = {f.name: i for i, f in enumerate(directories)}
    samples = [
        dict(
            folder=folder.name.encode(),
            label=classes[folder.name],
            file=str(img.relative_to(root)).encode(),
        )
        for folder in directories
        # iterdir() yields entries in arbitrary order; sort for reproducibility
        for img in sorted(folder.iterdir(), key=lambda x: x.name)
        if img.is_file()
    ]

    return dx.buffer_from_vector(samples).load_image(
        "file", prefix=str(root), output_key="image"
    )
2 | 3 | from .audio import FrequencyScale, WindowType, mfsc 4 | -------------------------------------------------------------------------------- /python/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package( 2 | Python 3 | COMPONENTS Interpreter Development.Module 4 | REQUIRED) 5 | execute_process( 6 | COMMAND "${Python_EXECUTABLE}" -m pybind11 --cmakedir 7 | OUTPUT_STRIP_TRAILING_WHITESPACE 8 | OUTPUT_VARIABLE pybind11_ROOT) 9 | find_package(pybind11 CONFIG REQUIRED) 10 | 11 | pybind11_add_module( 12 | _c 13 | ${CMAKE_CURRENT_LIST_DIR}/wrap.cpp 14 | ${CMAKE_CURRENT_LIST_DIR}/wrap_buffer.cpp 15 | ${CMAKE_CURRENT_LIST_DIR}/wrap_core.cpp 16 | ${CMAKE_CURRENT_LIST_DIR}/wrap_stream.cpp) 17 | 18 | target_include_directories(_c PUBLIC ${CMAKE_SOURCE_DIR}) 19 | target_link_libraries(_c PRIVATE mlxdata) 20 | target_compile_definitions(_c PRIVATE _VERSION_=${MLX_DATA_VERSION}) 21 | 22 | install(TARGETS _c DESTINATION mlx/data) 23 | -------------------------------------------------------------------------------- /python/tests/test_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | import string 4 | import unittest 5 | 6 | from mlx.data.core import BPEMerges, BPETokenizer, CharTrie 7 | 8 | 9 | class TestBpe(unittest.TestCase): 10 | def test_bpe(self): 11 | symbols = CharTrie() 12 | symbols.insert(" ") 13 | for s in string.ascii_letters: 14 | symbols.insert(s) 15 | n = symbols.num_keys() 16 | merges = BPEMerges() 17 | 18 | tokenizer = BPETokenizer(symbols, merges) 19 | 20 | self.assertEqual(tokenizer.tokenize("abcd"), [1, 2, 3, 4]) 21 | 22 | merges.add("a", "b", n + 1) 23 | self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, 3, 4]) 24 | 25 | merges.add("c", "d", n + 2) 26 | merges.add("b", "cd", n + 3) 27 | self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, n + 2]) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /python/tests/test_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import array 4 | import unittest 5 | 6 | import numpy as np 7 | 8 | import mlx.data as dx 9 | 10 | 11 | class TestBuffer(unittest.TestCase): 12 | def test__getitem__(self): 13 | n = 5 14 | b = dx.buffer_from_vector(list(dict(i=i) for i in range(n))) 15 | for i in range(n): 16 | self.assertEqual(b[i]["i"], i) 17 | i += 1 18 | self.assertEqual(b[-i], b[n - i]) 19 | 20 | with self.assertRaises(IndexError): 21 | _ = b[n] 22 | with self.assertRaises(IndexError): 23 | _ = b[-(n + 1)] 24 | 25 | def test_ordered_prefetch(self): 26 | """Test that elements are fetched in order.""" 27 | num_threads = 8 28 | prefetch_size = 16 29 | n = prefetch_size * 10 30 | buffer = dx.buffer_from_vector(list(dict(i=i) for i in range(n))) 31 | stream = buffer.ordered_prefetch(prefetch_size, num_threads) 32 | for i, e in enumerate(stream): 33 | self.assertEqual(i, e["i"]) 34 | 35 | def test_ordered_prefetch_edge_case(self): 36 | """Test when the buffer is smaller than dataset size.""" 37 | num_threads = 
4 38 | prefetch_size = 12 39 | n = int(prefetch_size * 0.5) 40 | buffer = dx.buffer_from_vector(list(dict(i=i) for i in range(n))) 41 | stream = buffer.ordered_prefetch(prefetch_size, num_threads) 42 | for i, e in enumerate(stream): 43 | self.assertEqual(i, e["i"]) 44 | 45 | def test_passing_python_objects(self): 46 | with self.assertRaises(ValueError): 47 | b = dx.buffer_from_vector([{"a": "hello"}]) 48 | with self.assertRaises(ValueError): 49 | b = dx.buffer_from_vector([{"a": object()}]) 50 | 51 | x = array.array("f") 52 | x.append(10) 53 | x.append(-2.5) 54 | y = np.random.randn(10) 55 | b = dx.buffer_from_vector( 56 | [ 57 | { 58 | "a": 1, 59 | "b": 1.2, 60 | "c": b"Hello world", 61 | "d": y, 62 | "e": x, 63 | } 64 | ] 65 | ) 66 | self.assertEqual(-2.5, b[0]["e"][1]) 67 | self.assertEqual(1, b[0]["a"]) 68 | self.assertTrue(np.all(y == b[0]["d"])) 69 | self.assertTrue(np.all(x == b[0]["e"])) 70 | 71 | # Check that we take np arrays without a copy 72 | y[0] = 0 73 | self.assertTrue(np.all(y == b[0]["d"])) 74 | 75 | # and buffers via copy 76 | x[0] = 0 77 | self.assertEqual(10, b[0]["e"][0]) 78 | 79 | 80 | if __name__ == "__main__": 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /python/tests/test_general_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
class TestGeneralOps(unittest.TestCase):
    """Checks for the ``slice`` and ``random_slice`` operations."""

    def test_slice(self):
        # Byte strings: keep elements [1, 3) along the only dimension.
        text_buffer = dx.buffer_from_vector([{"a": b"hello"}, {"a": b"world"}])
        sliced = text_buffer.slice("a", 0, 1, 3)
        for idx, expected in enumerate((b"el", b"or")):
            self.assertEqual(bytes(sliced[idx]["a"]), expected)

        # 3x4 arrays: slice along one dimension, then along both at once.
        base = np.arange(12).reshape(3, 4)
        array_buffer = dx.buffer_from_vector([{"a": base}, {"a": base + 10}])

        sliced = array_buffer.slice("a", 1, 1, 3)
        for idx in (0, 1):
            reference = array_buffer[idx]["a"][:, 1:3]
            self.assertTrue(np.all(sliced[idx]["a"] == reference))

        # The end bound (12) lies past the end of the 3-row dimension.
        sliced = array_buffer.slice("a", 0, 1, 12)
        for idx in (0, 1):
            reference = array_buffer[idx]["a"][1:, :]
            self.assertTrue(np.all(sliced[idx]["a"] == reference))

        sliced = array_buffer.slice("a", [0, 1], [0, 1], [1, 3])
        for idx in (0, 1):
            reference = array_buffer[idx]["a"][0:1, 1:3]
            self.assertTrue(np.all(sliced[idx]["a"] == reference))

        # A list of dims combined with scalar bounds must be rejected.
        with self.assertRaises(ValueError):
            sliced = array_buffer.slice("a", [0, 1], 2, 3)

    def test_random_slice(self):
        text_buffer = dx.buffer_from_vector([{"a": b"hello"}, {"a": b"world"}])
        stream = text_buffer.to_stream().repeat(-1).random_slice("a", 0, 3)

        # Every length-3 window of each word, in the order the words repeat.
        windows_per_word = [
            {b"hel", b"ell", b"llo"},
            {b"wor", b"orl", b"rld"},
        ]
        for step, sample in zip(range(20), stream):
            allowed = windows_per_word[step % 2]
            self.assertTrue(bytes(sample["a"]) in allowed)
2 | 3 | import unittest 4 | 5 | import mlx.data as dx 6 | 7 | 8 | class TestReplace(unittest.TestCase): 9 | def test_replace(self): 10 | s = "Hello world".encode() 11 | dset = dx.buffer_from_vector([dict(text=s)]) 12 | 13 | ds = dset.replace("text", "world", "everybody!") 14 | self.assertEqual(bytes(ds[0]["text"]), b"Hello everybody!") 15 | 16 | ds = dset.replace("text", "l", "b") 17 | self.assertEqual(bytes(ds[0]["text"]), b"Hebbo worbd") 18 | 19 | ds = dset.replace("text", "l", "b", 2) 20 | self.assertEqual(bytes(ds[0]["text"]), b"Hebbo world") 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /super/cmake/aws-1.11.557.patch: -------------------------------------------------------------------------------- 1 | diff --git a/crt/aws-crt-cpp/crt/aws-c-cal/CMakeLists.txt b/crt/aws-crt-cpp/crt/aws-c-cal/CMakeLists.txt 2 | index fbb502d..1af7612 100644 3 | --- a/crt/aws-crt-cpp/crt/aws-c-cal/CMakeLists.txt 4 | +++ b/crt/aws-crt-cpp/crt/aws-c-cal/CMakeLists.txt 5 | @@ -102,7 +102,7 @@ if (NOT BYO_CRYPTO) 6 | if (USE_OPENSSL AND NOT ANDROID) 7 | find_package(OpenSSL REQUIRED) 8 | find_package(Threads REQUIRED) 9 | - list(APPEND PLATFORM_LIBS OpenSSL::Crypto Threads::Threads) 10 | + list(APPEND PLATFORM_LIBS OpenSSL::SSL OpenSSL::Crypto Threads::Threads) 11 | message(STATUS "Using libcrypto from system: ${OPENSSL_CRYPTO_LIBRARY}") 12 | elseif(NOT USE_OPENSSL AND IN_SOURCE_BUILD) 13 | if (TARGET crypto) 14 | -------------------------------------------------------------------------------- /super/cmake/bzip2-1.0.8.patch: -------------------------------------------------------------------------------- 1 | diff -Naur bzip2/Makefile bzip2-patch/Makefile 2 | --- bzip2/Makefile 2024-01-05 14:13:04.891566438 -0800 3 | +++ bzip2-patch/Makefile 2024-01-05 13:57:31.611942881 -0800 4 | @@ -21,7 +21,7 @@ 5 | LDFLAGS=-fPIE 6 | 7 | BIGFILES=-D_FILE_OFFSET_BITS=64 8 | -CFLAGS=-Wall -Winline 
-O2 -g $(BIGFILES) 9 | +CFLAGS=-Wall -Winline -O2 -fPIC -g $(BIGFILES) 10 | 11 | # Where you want it installed when you do 'make install' 12 | PREFIX=/usr/local 13 | -------------------------------------------------------------------------------- /super/cmake/flac-1.5.0.patch: -------------------------------------------------------------------------------- 1 | diff --git a/CMakeLists.txt b/CMakeLists.txt 2 | index fb23b7d9..77463836 100644 3 | --- a/CMakeLists.txt 4 | +++ b/CMakeLists.txt 5 | @@ -213,7 +213,7 @@ endif() 6 | # The following folder layout is mostly for MSVC 7 | set_property(GLOBAL PROPERTY USE_FOLDERS ON) 8 | 9 | -set_target_properties(FLAC grabbag getopt replaygain_analysis replaygain_synthesis utf8 PROPERTIES FOLDER Libraries) 10 | +set_target_properties(FLAC grabbag replaygain_analysis replaygain_synthesis utf8 PROPERTIES FOLDER Libraries) 11 | if(BUILD_CXXLIBS) 12 | set_target_properties(FLAC++ PROPERTIES FOLDER Libraries) 13 | endif() 14 | diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt 15 | index 262feead..4a4e4cb0 100644 16 | --- a/src/CMakeLists.txt 17 | +++ b/src/CMakeLists.txt 18 | @@ -9,7 +9,6 @@ if(BUILD_CXXLIBS) 19 | endif() 20 | add_subdirectory("share/replaygain_analysis") 21 | add_subdirectory("share/replaygain_synthesis") 22 | -add_subdirectory("share/getopt") 23 | add_subdirectory("share/utf8") 24 | add_subdirectory("share/grabbag") 25 | 26 | -------------------------------------------------------------------------------- /super/cmake/xvidcore-1.3.7.patch: -------------------------------------------------------------------------------- 1 | --- a/build/generic/Makefile 2 | +++ b/build/generic/Makefile 3 | @@ -149,21 +149,6 @@ 4 | @$(INSTALL) -d $(DESTDIR)$(libdir) 5 | @echo " I: $(libdir)/$(STATIC_LIB)" 6 | @$(INSTALL) -m 644 $(BUILD_DIR)/$(STATIC_LIB) $(DESTDIR)$(libdir)/$(STATIC_LIB) 7 | -ifeq ($(SHARED_EXTENSION),dll) 8 | - @echo " I: $(libdir)/$(SHARED_LIB).a" 9 | - @$(INSTALL) -m 644 $(BUILD_DIR)/$(SHARED_LIB).a 
$(DESTDIR)$(libdir)/$(SHARED_LIB).a 10 | - @echo " D: $(bindir)" 11 | - @$(INSTALL) -d $(DESTDIR)$(bindir) 12 | - @echo " I: $(bindir)/$(SHARED_LIB)" 13 | - @$(INSTALL) -m 755 $(BUILD_DIR)/$(SHARED_LIB) $(DESTDIR)$(bindir)/$(SHARED_LIB) 14 | -else 15 | - @echo " I: $(libdir)/$(SHARED_LIB)" 16 | - @$(INSTALL) -m 644 $(BUILD_DIR)/$(SHARED_LIB) $(DESTDIR)$(libdir)/$(SHARED_LIB) 17 | - @test -z "$(SO_API_MAJOR_LINK)" || \ 18 | - $(LN_S) $(SHARED_LIB) $(DESTDIR)$(libdir)/$(SO_API_MAJOR_LINK) 19 | - @test -z "$(SO_LINK)" || \ 20 | - $(LN_S) $(SHARED_LIB) $(DESTDIR)$(libdir)/$(SO_LINK) 21 | -endif 22 | 23 | #----------------------------------------------------------------------------- 24 | # Platorm specific file -- dumb rules for people executing make before 25 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # coming soon 2 | --------------------------------------------------------------------------------