├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── analyse_sweep.py ├── assets ├── local_sweep.png └── remote_sweep.png ├── benchmark ├── __init__.py ├── dataset.py └── trainer.py ├── pyproject.toml ├── run_benchmark.py ├── run_sweep.sh ├── stocaching └── __init__.py └── tests ├── __init__.py └── test_shared_cache.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: "3.10" 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install poetry 21 | - name: Build and publish 22 | run: | 23 | poetry version $(git describe --tags --abbrev=0) 24 | poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }} 25 | poetry build 26 | poetry publish 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # poetry.lock not necessary in library project 114 | poetry.lock 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | outputs/ 135 | remote-outputs/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Charles Jones 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Caching 2 | 3 | Is your training pipeline data-bottlenecked? Are you looking for a zero-effort speedup? You've come to the right place! 4 | 5 | Introducing `stocaching`, a tiny (1 file) library for stochastic dataset caching in PyTorch. 6 | 7 | See [this blog post](https://charl-ai.github.io/blog/dataloaders/) if you want to understand the benefits, motivation, and decisions behind the library. 8 | 9 | Feel free to simply copy-paste the module into your projects! Alternatively, you can install with pip: 10 | 11 | ```bash 12 | pip install stocaching 13 | ``` 14 | 15 | ## Usage 16 | 17 | Adding stochastic caching to your existing datasets is dead simple. Simply follow these two steps: 18 | 19 | 1. Create a `SharedCache` object in the `__init__` method of your dataset. You tell `SharedCache` about the size of your dataset and the maximum amount of space you want the cache to take up. `SharedCache` then calculates the maximum number of samples that can fit, and allocates that many slots to store data in. 20 | 21 | 2. In the `__getitem__` method of your dataset, interact with the `SharedCache` object to either read the data from the cache (if it has already been cached), or write the data to the cache (if it has not yet been cached). 
22 | 23 | You can get and set items in the cache with `x = cache[idx]`, and `cache[idx] = x`. You can picture the cache as a list-like structure with a slot for each sample. 24 | 25 | When the dataset is too large to cache completely, `len(cache) < len(dataset)`. If you used the getter and setter directly, you would end up with lots of fiddly code to check if idx is in bounds for the cache. We provide two convenience methods `get_slot`, and `set_slot`, which allow you to treat the cache as if it has the same length as the dataset. Using `get_slot` out of bounds of the cache simply returns `None`. Using `set_slot` out of bounds is a no-op. These methods are designed to minimise the amount of code you need to write in the `__getitem__` method of your dataset. 26 | 27 | **Advanced:** Internally, the cache is simply a single pytorch array, backed by shared memory. You can access the underlying array with the `array` property. We also keep an auxiliary array in shared memory, which tracks which samples have been cached, which are yet to be cached, and which are out-of-bounds of the cache. You can access it directly with the `aux_array` property. 28 | 29 | ## Example 30 | 31 | ```python 32 | import torch 33 | from stocaching import SharedCache 34 | from torch.utils.data import Dataset 35 | 36 | class MyDataset(Dataset): 37 | def __init__(self): 38 | super().__init__() 39 | 40 | ... # set up dataset 41 | 42 | dataset_len = N # number of samples in the full dataset 43 | data_dims = (C, H, W) # data dims (not including batch) 44 | 45 | # initialize the cache 46 | self.cache = SharedCache( 47 | size_limit_gib=32, 48 | dataset_len=dataset_len, 49 | data_dims=data_dims, 50 | dtype=torch.uint8, 51 | ) 52 | 53 | def __getitem__(self, idx): 54 | # retrieve data from cache if it's there 55 | x = self.cache.get_slot(idx) 56 | # x will be None if the cache slot was empty or OOB 57 | if x is None: 58 | x = ... # load data to uint8 tensor from disk 59 | self.cache.set_slot(idx, x) # try to cache x 60 | return x 61 | ``` 62 | 63 | ## Benchmarks 64 | 65 | We run some basic benchmarks for stochastic caching under a realistic workload -- single GPU image classification. 66 | 67 | We train `mobilenet_v3_small` on a 50k sample dataset for two epochs. The reason we use such a tiny model is to ensure that we are in the dataloading-bottlenecked regime. In the first epoch, the stochastic cache is being lazily filled. In the second epoch, the cache is being read from. 68 | 69 | We perform two sets of experiments: one with the data on a local HDD, and one with the data being read from another machine on the network. All experiments are run on the same machine: RTX 3090 GPU, i9 10th gen CPU. 70 | 71 | In all epochs apart from the first, stochastic caching gives a speedup that scales linearly with the percentage of the dataset being cached. There is a very small overhead in the first epoch (due to filling the cache), but by the end of the second epoch, the speedup from caching more than compensates for this. 72 | 73 | | Local HDD | Remote data | 74 | | :-------------------------: | :--------------------------: | 75 | | ![](assets/local_sweep.png) | ![](assets/remote_sweep.png) | 76 | 77 | ## FAQ 78 | 79 | ### How much memory should I allocate to the cache? 80 | 81 | As much as you like! The speedup from caching scales linearly with the % of your dataset being cached. 82 | 83 | The shared memory is stored in `/dev/shm` (tmpfs), so this is likely the limiting factor for you. 
We provide a convenience function `get_shm_size` to check how large it is. Alternatively, check with `df -h`. 84 | 85 | Most Unix-like systems have `/dev/shm` pre-set to 50% of your RAM. You can temporarily resize it (e.g. to 128 GiB) by running: `mount -o remount,size=128G /dev/shm` (warning: do this at your own risk). 86 | 87 | ### How does this interact with augmentations/transforms? 88 | 89 | Generally, you don't want to do any random augmentations before caching because the cache will kill the randomness. It's also a good idea to cache data in uint8 format (instead of float32) to save space. 90 | 91 | Splitting your transforms/augmentation pipeline into two phases is a good idea. The first phase converts your data to a (possibly resized) uint8 tensor. The output of this phase gets cached. The second phase should do random augmentations, convert to float32, and normalise. This phase happens 'on-line' and the output goes straight into your model. 92 | 93 | For an example of how to do this properly, see the implementation in `benchmark/dataset.py`. You can also read the [blog post](https://charl-ai.github.io/blog/dataloaders/) for more information. 94 | 95 | ### Does this work with multi-GPU (DDP) training? 96 | 97 | Almost. I'll push an update to support it soon. 98 | 99 | ### How do I reproduce the benchmarks? 100 | 101 | If you feel like it, please reproduce these benchmarks on your setup! 102 | 103 | We benchmark the method with the minimal example in the `benchmark/` dir. You may perform a single benchmark run like so: 104 | 105 | ```bash 106 | python run_benchmark.py --data-dir="your-data-dir" --cache-limit-gib="your-cache limit" 107 | ``` 108 | 109 | Set `data-dir` to a location on an HDD or network drive. The first time you run the code, a random dataset will be generated in this directory. Set `cache-limit-gib` to 0 to disable caching, or to an integer less than the size of `/dev/shm`. 110 | 111 | By default, the benchmark script generates a dummy dataset, with 50k (3,512,512) jpg images. This takes around 7.5 GiB on disk. Around 9 GiB of shared memory is needed to fully cache the dataset. 112 | 113 | All our benchmarks use the default hyperparameters specified in the `run_benchmark.py` file {batch_size: 256, num_workers: 8, pin_memory: True}. 
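Before launching a run, it helps to know how much shared memory is available. Here is a minimal sketch using the library's `get_shm_size` helper, which reports the size of `/dev/shm` in GiB:

```python
from stocaching import get_shm_size

# get_shm_size() returns the size of /dev/shm in GiB (as an int).
# Pick a --cache-limit-gib value no larger than this.
print(f"/dev/shm size: {get_shm_size()} GiB")
```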
114 | 115 | You can run the entire benchmark sweep like so: 116 | 117 | ```bash 118 | ./run_sweep.sh "YOUR_DATA_DIR" "YOUR_OUTPUT_DIR" 119 | ``` 120 | 121 | You may then reproduce our plots by running: 122 | 123 | ```bash 124 | python analyse_sweep.py --csv-dir="WHERE_YOU_SAVED_THE_OUTPUTS" --fig-save-dir="assets/" 125 | ``` 126 | -------------------------------------------------------------------------------- /analyse_sweep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import seaborn as sns 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--csv-dir", type=str, default="outputs/") 11 | parser.add_argument("--fig-save-dir", type=str, default="assets/") 12 | 13 | DATASET_SIZE_GIB = 9.2 14 | 15 | 16 | def main(args): 17 | files = glob.glob(os.path.join(args.csv_dir, "*.csv")) 18 | 19 | df = pd.concat((pd.read_csv(f) for f in files), ignore_index=True) 20 | 21 | for col in df.columns: 22 | if col not in ["cache_limit_gib", "epoch_0_time", "epoch_1_time"]: 23 | df = df.drop(col, axis=1) 24 | 25 | df = df.melt( 26 | id_vars=["cache_limit_gib"], 27 | value_vars=["epoch_0_time", "epoch_1_time"], 28 | var_name="epoch", 29 | value_name="time", 30 | ) 31 | 32 | df["cache_limit_gib"] = df["cache_limit_gib"] / DATASET_SIZE_GIB * 100 33 | df["cache_limit_gib"] = df["cache_limit_gib"].clip(upper=100) 34 | 35 | df = df.rename(columns={"cache_limit_gib": "Cache Limit (%)"}) 36 | df = df.rename(columns={"time": "Time (s)"}) 37 | df = df.rename(columns={"epoch": "Epoch"}) 38 | df["Epoch"] = df["Epoch"].replace({"epoch_0_time": "0", "epoch_1_time": "1"}) 39 | 40 | sns.lmplot( 41 | data=df, 42 | x="Cache Limit (%)", 43 | y="Time (s)", 44 | hue="Epoch", 45 | x_jitter=2.0, 46 | ) 47 | plt.ylim(bottom=0) 48 | plt.savefig(os.path.join(args.fig_save_dir, "sweep.png")) 49 | 50 | 51 | if __name__ == "__main__": 52 | args = parser.parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /assets/local_sweep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charl-AI/stochastic-caching/e91381d5089ca2fb704c5cb634ab62f61dc9fa33/assets/local_sweep.png -------------------------------------------------------------------------------- /assets/remote_sweep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charl-AI/stochastic-caching/e91381d5089ca2fb704c5cb634ab62f61dc9fa33/assets/remote_sweep.png -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charl-AI/stochastic-caching/e91381d5089ca2fb704c5cb634ab62f61dc9fa33/benchmark/__init__.py -------------------------------------------------------------------------------- /benchmark/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision.datasets import FakeData 9 | from torchvision.transforms import v2 10 | from tqdm import tqdm 11 | 12 | from stocaching import SharedCache 13 | 14 | DATASET_LEN = 50000 15 | RAW_DATA_DIMS = (3, 512, 512) 16
| CACHED_DATA_DIMS = (3, 256, 256) 17 | FINAL_DATA_DIMS = (3, 224, 224) 18 | 19 | # for normalisation 20 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 21 | IMAGENET_STD = [0.229, 0.224, 0.225] 22 | 23 | 24 | def save_dummy_data( 25 | data_dir: str, 26 | size: int = DATASET_LEN, 27 | img_size: tuple[int, int, int] = RAW_DATA_DIMS, 28 | ) -> None: 29 | print(f"Saving dummy data to {data_dir}...") 30 | os.makedirs(data_dir, exist_ok=True) 31 | ds = FakeData(size, img_size) 32 | skipped = 0 33 | for i in tqdm(range(len(ds)), desc="Saving dummy data"): 34 | img_path = os.path.join(data_dir, f"{i}.jpg") 35 | if os.path.exists(img_path): 36 | skipped += 1 37 | continue 38 | x, _ = ds[i] 39 | img: Image.Image = x 40 | img.save(img_path) 41 | 42 | print( 43 | f"Dummy data saved to {data_dir}" 44 | + f" ({size - skipped} / {size} images saved. {skipped} imgs already existed)." 45 | ) 46 | 47 | 48 | def get_transforms() -> Callable[[Image.Image], torch.Tensor]: 49 | """Transforms map a PIL image to a uint8 torch Tensor. 50 | This is applied before caching, so it is important that no stochastic 51 | operations are included. 52 | """ 53 | transform_list = [ 54 | v2.ToImage(), 55 | v2.ToDtype(torch.uint8), 56 | v2.Resize(CACHED_DATA_DIMS[1], antialias=True), 57 | ] 58 | 59 | return v2.Compose(transform_list) 60 | 61 | 62 | def get_augmentations() -> Callable[[torch.Tensor], torch.Tensor]: 63 | """Augmentations map a torch uint8 Tensor to a normalised float Tensor 64 | This is applied after caching, so should include all stochastic operations. 65 | """ 66 | aug_list = [ 67 | v2.RandomResizedCrop(FINAL_DATA_DIMS[1], antialias=True), 68 | v2.RandAugment(), 69 | v2.ToDtype(torch.float32), 70 | v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), 71 | ] 72 | return v2.Compose(aug_list) 73 | 74 | 75 | class DummyDataset(Dataset): 76 | def __init__(self, data_dir: str, cache_limit_gib: int): 77 | """PyTorch dataset for dummy data. 
78 | No cache is used if cache_limit_gib is 0.""" 79 | self.data_dir = data_dir 80 | self.cache_limit_gib = cache_limit_gib 81 | self.transforms = get_transforms() 82 | self.augmentations = get_augmentations() 83 | 84 | save_dummy_data(data_dir) 85 | 86 | if cache_limit_gib != 0: 87 | self.cache = SharedCache( 88 | cache_limit_gib, DATASET_LEN, CACHED_DATA_DIMS, dtype=torch.uint8 89 | ) 90 | 91 | def _get_img(self, idx) -> torch.Tensor: 92 | """Reads dummy data from disk to a uint8 torch tensor.""" 93 | img_path = os.path.join(self.data_dir, f"{idx}.jpg") 94 | img = Image.open(img_path).convert("RGB") 95 | img = self.transforms(img) 96 | return img 97 | 98 | def __len__(self): 99 | return DATASET_LEN 100 | 101 | def __getitem__(self, idx) -> torch.Tensor: 102 | # caching disabled 103 | if self.cache_limit_gib == 0: 104 | return self.augmentations(self._get_img(idx)) 105 | 106 | # try to read the image from cache 107 | img = self.cache.get_slot(idx) 108 | # otherwise, read from disk and try to cache 109 | if img is None: 110 | img = self._get_img(idx) # uint8 tensor 111 | self.cache.set_slot(idx, img) 112 | 113 | return self.augmentations(img) 114 | -------------------------------------------------------------------------------- /benchmark/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torchvision.models import mobilenet_v3_small 7 | from tqdm import tqdm 8 | 9 | 10 | def train(loader: DataLoader, num_epochs: int) -> list[float]: 11 | # using a tiny model because we do not want to be compute bound when 12 | # benchmarking the dataloading 13 | model = mobilenet_v3_small() 14 | model.to("cuda") 15 | prediction_head = nn.Linear(1000, 10) 16 | prediction_head.to("cuda") 17 | model.train() 18 | params = list(model.parameters()) + list(prediction_head.parameters()) 19 | optim = torch.optim.Adam(params, lr=1e-3) 20 | 21 | times = [] 22 | for epoch in range(num_epochs): 23 | torch.cuda.synchronize() 24 | epoch_start = time.time() 25 | for batch in tqdm(loader, desc=f"Epoch {epoch}"): 26 | x = batch.to("cuda") 27 | embeddings = model(x) 28 | logits = prediction_head(embeddings) 29 | # just random labels 30 | y = torch.randint(0, 10, (len(x),), device="cuda") 31 | loss = nn.CrossEntropyLoss()(logits, y) 32 | optim.zero_grad() 33 | loss.backward() 34 | optim.step() 35 | 36 | torch.cuda.synchronize() # wait for all computations to finish 37 | epoch_end = time.time() 38 | epoch_time = epoch_end - epoch_start 39 | times.append(epoch_time) 40 | print(f"Epoch {epoch} took {epoch_time:.3f}s") 41 | return times 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "stocaching" 3 | version = "0.2.0" 4 | description = "A tiny library for stochastic dataset caching in PyTorch." 
5 | authors = ["C Jones "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.10,<3.13" 11 | # NB: torch 2.1.0 has a known issue: 12 | # https://stackoverflow.com/questions/76327419/valueerror-libcublas-so-0-9-not-found-in-the-system-path 13 | # an easy workaround for development is just to reinstall with pip 14 | # this does not affect users of this library 15 | torch = "^2.1.0" 16 | numpy = "^1.26.1" 17 | 18 | 19 | [tool.poetry.group.dev.dependencies] 20 | torchvision = "^0.16.0" 21 | tqdm = "^4.66.1" 22 | matplotlib = "^3.8.0" 23 | pandas = "^2.1.1" 24 | seaborn = "^0.13.0" 25 | pytest = "^7.4.3" 26 | 27 | [build-system] 28 | requires = ["poetry-core"] 29 | build-backend = "poetry.core.masonry.api" 30 | -------------------------------------------------------------------------------- /run_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from benchmark.dataset import DummyDataset 9 | from benchmark.trainer import train 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--seed", type=int, default=42) 13 | parser.add_argument("--data-dir", type=str, default="/data2/dummy_data") 14 | parser.add_argument("--cache-limit-gib", type=int, default=0) 15 | parser.add_argument("--batch-size", type=int, default=256) 16 | parser.add_argument("--num-workers", type=int, default=8) 17 | parser.add_argument("--pin-memory", type=bool, default=True) 18 | parser.add_argument("--output-dir", type=str, default="outputs/") 19 | 20 | NUM_EPOCHS = 2 21 | 22 | 23 | def main(args): 24 | # first epoch fills cache, second epoch uses cache 25 | torch.manual_seed(args.seed) 26 | dataset = DummyDataset(args.data_dir, args.cache_limit_gib) 27 | loader = DataLoader( 28 | dataset, 29 | batch_size=args.batch_size, 30 | num_workers=args.num_workers, 31 | shuffle=True, 32 | pin_memory=args.pin_memory, 33 | ) 34 | times = train(loader, NUM_EPOCHS) 35 | assert len(times) == 2 36 | epoch_0_time = times[0] 37 | epoch_1_time = times[1] 38 | 39 | os.makedirs(args.output_dir, exist_ok=True) 40 | 41 | run_stats = { 42 | "seed": [args.seed], 43 | "data_dir": [args.data_dir], 44 | "cache_limit_gib": [args.cache_limit_gib], 45 | "batch_size": [args.batch_size], 46 | "num_workers": [args.num_workers], 47 | "pin_memory": [args.pin_memory], 48 | "epoch_0_time": [epoch_0_time], 49 | "epoch_1_time": [epoch_1_time], 50 | } 51 | print(f"Run stats: {run_stats}") 52 | df = pd.DataFrame.from_dict(run_stats) 53 | df.to_csv( 54 | os.path.join(args.output_dir, f"run_{args.seed}_{args.cache_limit_gib}.csv") 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | args = parser.parse_args() 60 | main(args) 61 | -------------------------------------------------------------------------------- /run_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # first argument is data dir 4 | data_dir=$1 5 | 6 | # second argument is output dir 7 | output_dir=$2 8 | 9 | seeds=( 1 2 3 ) 10 | cache_sizes=( 0 1 2 3 4 5 6 7 8 9 ) 11 | 12 | for seed in "${seeds[@]}" 13 | do 14 | for cache_size in "${cache_sizes[@]}" 15 | do 16 | python3 run_benchmark.py --seed "$seed" --cache-limit-gib "$cache_size" --data-dir "$data_dir" --output-dir "$output_dir" 17 | done 18 | done 19 | 20 | 21 | -------------------------------------------------------------------------------- 
/stocaching/__init__.py: -------------------------------------------------------------------------------- 1 | """Stocaching, a tiny library for stochastic dataset caching in PyTorch.""" 2 | 3 | import ctypes 4 | import multiprocessing as mp 5 | import os 6 | from enum import Enum 7 | 8 | import numpy as np 9 | import torch 10 | 11 | __all__ = ["SharedCache", "get_shm_size"] 12 | 13 | BYTES_PER_GIB = 1024**3 14 | 15 | C_DTYPES = { 16 | torch.bool: ctypes.c_bool, 17 | torch.uint8: ctypes.c_uint8, 18 | torch.int8: ctypes.c_int8, 19 | torch.int16: ctypes.c_int16, 20 | torch.int32: ctypes.c_int32, 21 | torch.int64: ctypes.c_int64, 22 | torch.float32: ctypes.c_float, 23 | torch.float64: ctypes.c_double, 24 | } 25 | 26 | 27 | class SlotState(Enum): 28 | EMPTY = 0 29 | SET = 1 30 | # in this context, OOB means outside the range of the cache, 31 | # but inside the range of the full dataset 32 | OOB = 2 33 | 34 | 35 | class SharedCache: 36 | """A simple shared memory cache for use in PyTorch datasets. 37 | 38 | You can set a size limit for the cache to take. If your dataset 39 | exceeds this size, the cache will only allocate slots for the first N samples. 40 | This allows you to speed up training by caching only a subset of your dataset. 41 | When applied to a large, shuffled dataset, we call this 'stochastic caching'. 42 | 43 | You may interact with the cache directly as if it were a list of slots, 44 | with one slot per sample. Get and set with `x = cache[0]` and `cache[0] = x`. 45 | 46 | Using the getter and setter directly can be fiddly if you are only caching part 47 | of the dataset. We expose two convenience methods (`get_slot` and `set_slot`), 48 | which simplify usage by allowing you to treat the cache as if it were the same 49 | size as the full dataset. 50 | 51 | Example usage: 52 | 53 | ```python 54 | import torch 55 | from stocaching import SharedCache 56 | from torch.utils.data import Dataset 57 | 58 | class MyDataset(Dataset): 59 | def __init__(self): 60 | super().__init__() 61 | 62 | ... # set up dataset 63 | 64 | dataset_len = N # number of samples in the full dataset 65 | data_dims = (3, 32, 32) # data dims (not including batch) 66 | 67 | # initialize the cache 68 | self.cache = SharedCache( 69 | size_limit_gib=32, 70 | dataset_len=dataset_len, 71 | data_dims=data_dims, 72 | dtype=torch.uint8 73 | ) 74 | def __getitem__(self, idx): 75 | # retrieve data from cache if it's there 76 | x = self.cache.get_slot(idx) 77 | # x will be None if the cache slot was empty or OOB 78 | if x is None: 79 | x = ... # load data to uint8 tensor from disk 80 | # try to cache x, no-op if idx is OOB of the cache 81 | self.cache.set_slot(idx, x) 82 | return x 83 | ``` 84 | """ 85 | 86 | def __init__( 87 | self, 88 | size_limit_gib: int, 89 | dataset_len: int, 90 | data_dims: tuple[int, ...], 91 | dtype: torch.dtype = torch.uint8, 92 | ) -> None: 93 | """ 94 | Args: 95 | size_limit_gib (int): Maximum size of the cache in GiB. 96 | dataset_len (int): Length (number of samples) of the full dataset. 97 | data_dims (tuple[int, ...]): Dimensions of the data to be stored in the 98 | cache. E.g. (C, H, W) for 2D image data. Does not include batch dim. 99 | dtype (torch.dtype, optional): Torch data type of the dataset to cache. 100 | Must be in the subset of torch dtypes with corresponding ctypes: 101 | bool, uint8, int8, int16, int32, int64, float32, float64. 102 | Defaults to torch.uint8 (this is usually best for jpg images). 
103 | """ 104 | if dtype not in C_DTYPES: 105 | raise ValueError( 106 | f"Unsupported dtype: {dtype}. Must be one of {C_DTYPES.keys()}" 107 | ) 108 | dtype_bytes = dtype.itemsize 109 | slot_bytes = int(np.prod(data_dims) * dtype_bytes) 110 | dataset_bytes = slot_bytes * dataset_len 111 | size_limit_bytes = size_limit_gib * BYTES_PER_GIB 112 | 113 | # we allocate a flat 8-bit array to keep track of which samples are cached, 114 | # which are not cached yet, and which are out of bounds of the cache 115 | aux_bytes = dataset_len * torch.uint8.itemsize 116 | 117 | ds_and_aux_bytes = dataset_bytes + aux_bytes 118 | 119 | if ds_and_aux_bytes > size_limit_bytes: 120 | cache_len = int((size_limit_bytes - aux_bytes) / slot_bytes) 121 | print( 122 | f"Dataset size ({ds_and_aux_bytes / BYTES_PER_GIB:.1f} GiB)" 123 | + f" exceeds cache limit ({size_limit_gib} GiB)." 124 | + f" Allocating space to cache {cache_len} / {dataset_len} samples." 125 | ) 126 | 127 | else: 128 | cache_len = dataset_len 129 | print( 130 | f"Dataset size ({ds_and_aux_bytes / BYTES_PER_GIB:.1f} GiB)" 131 | + f" fits in cache limit ({size_limit_gib} GiB)." 132 | + f" Allocating space to cache all {cache_len} samples." 133 | ) 134 | 135 | shared_array_base = mp.Array( 136 | C_DTYPES[dtype], int(np.prod(data_dims)) * cache_len 137 | ) 138 | shared_array = np.ctypeslib.as_array(shared_array_base.get_obj()) 139 | shared_array = shared_array.reshape((cache_len, *data_dims)) 140 | self._shm = torch.from_numpy(shared_array) 141 | self._shm *= 0 142 | 143 | shared_aux_base = mp.Array(C_DTYPES[torch.uint8], dataset_len) 144 | shared_aux = np.ctypeslib.as_array(shared_aux_base.get_obj()) 145 | self._aux = torch.from_numpy(shared_aux) 146 | self._aux *= 0 147 | 148 | # only cache the first cache_len samples by index 149 | self._aux[cache_len:] = SlotState.OOB.value 150 | 151 | @property 152 | def array(self) -> torch.Tensor: 153 | """Access the full underlying cache (just a tensor backed by shared memory). 154 | Returns a torch tensor of shape (cache_len, *data_dims). 155 | The dtype is whatever you specified when constructing the cache.""" 156 | return self._shm 157 | 158 | @property 159 | def aux_array(self) -> torch.Tensor: 160 | """Access the auxiliary array (just a tensor backed by shared memory). 161 | The auxiliary array keeps track of which samples from the full dataset have been 162 | cached, which samples are yet to be cached, and which are OOB. 163 | Returns a shared memory torch uint8 tensor, shape (dataset_len,). 164 | `self.aux_array[idx] == 0` means sample idx is not cached. 165 | `self.aux_array[idx] == 1` means sample idx is cached. 166 | `self.aux_array[idx] == 2` means sample idx is OOB. 167 | """ 168 | return self._aux 169 | 170 | def __getitem__(self, idx: int): 171 | return self.array[idx] 172 | 173 | def __setitem__(self, idx: int, value: torch.Tensor): 174 | self.array[idx] = value 175 | 176 | def __len__(self): 177 | return len(self.array) 178 | 179 | def _slot_state(self, idx: int) -> SlotState: 180 | """Get the state of a slot in the cache. 
Raises an error if idx is outside 181 | the range of the full dataset.""" 182 | if idx < 0 or idx >= len(self.aux_array): 183 | raise IndexError( 184 | f"Index {idx} out of bounds for dataset of length {len(self.aux_array)}" 185 | ) 186 | return SlotState(self._aux[idx].item()) 187 | 188 | def set_slot( 189 | self, 190 | idx: int, 191 | value: torch.Tensor, 192 | allow_oob_idx: bool = True, 193 | allow_overwrite: bool = False, 194 | ) -> None: 195 | """Set a slot in the cache to a value. 196 | 197 | The main reason to use this method over __setitem__ is that we 198 | allow you to call this method when idx is out of bounds of the 199 | cache, but within the range of the full dataset. 200 | 201 | In this case the method is a no-op when idx is out of bounds. 202 | 203 | Args: 204 | idx (int): Index of the slot to set. 205 | value (torch.Tensor): Value to set the slot to. 206 | allow_oob_idx (bool, optional): When False, raises an error if 207 | idx is out of bounds of the cache. Defaults to True. 208 | allow_overwrite (bool, optional): When False, raises an error if 209 | the slot is not empty. Defaults to False. 210 | """ 211 | slot_state = self._slot_state(idx) 212 | if slot_state == SlotState.OOB: 213 | if not allow_oob_idx: 214 | raise IndexError( 215 | f"Index {idx} out of bounds of SharedCache of length {len(self)}" 216 | ) 217 | return # no-op 218 | 219 | if slot_state == SlotState.SET and not allow_overwrite: 220 | raise RuntimeError( 221 | f"Tried to overwrite non-empty slot {idx=} in SharedCache." 222 | ) 223 | 224 | self[idx] = value 225 | self.aux_array[idx] = SlotState.SET.value 226 | 227 | def get_slot( 228 | self, 229 | idx: int, 230 | allow_oob_idx: bool = True, 231 | allow_empty_slot: bool = True, 232 | ) -> torch.Tensor | None: 233 | """Get the value of a slot in the cache. 234 | 235 | The main reason to use this method over __getitem__ is that we 236 | allow you to call this method when idx is out of bounds of the 237 | cache, but within the range of the dataset. 238 | 239 | In this case the method returns None when idx is out of bounds. 240 | 241 | Args: 242 | idx (int): Index of the slot to get. 243 | allow_oob_idx (bool, optional): When False, raises an error if 244 | idx is out of bounds of the cache. Defaults to True. 245 | allow_empty_slot (bool, optional): When True, returns 246 | None if the slot is empty. Otherwise, raises 247 | an exception. Defaults to True. 248 | """ 249 | slot_state = self._slot_state(idx) 250 | if slot_state == SlotState.OOB: 251 | if not allow_oob_idx: 252 | raise IndexError( 253 | f"Index {idx} out of bounds of SharedCache of length {len(self)}" 254 | ) 255 | return None 256 | 257 | if slot_state == SlotState.EMPTY: 258 | if allow_empty_slot: 259 | return None 260 | else: 261 | raise RuntimeError( 262 | f"Tried to read from an empty slot {idx=} in SharedCache." 263 | ) 264 | return self[idx] 265 | 266 | def clear(self) -> None: 267 | """Clear all slots in the cache.""" 268 | self._shm *= 0 269 | self._aux *= 0 270 | self._aux[len(self) :] = SlotState.OOB.value 271 | 272 | 273 | def get_shm_size() -> int: 274 | """Get size of /dev/shm. The size limit of the shared memory cache 275 | should not exceed this. 276 | 277 | N.B. You may check the size of /dev/shm on the command line with `df -h`. 278 | A simple way to (temporarily) change it is to run: 279 | `mount -o remount,size=128G /dev/shm` (change to 128 GiB, for example). 
280 | 281 | Returns: 282 | (int) Size of /dev/shm in GiB 283 | """ 284 | stats = os.statvfs("/dev/shm") 285 | shm_bytes = stats.f_bsize * stats.f_blocks 286 | shm_size = shm_bytes / BYTES_PER_GIB 287 | return int(shm_size) 288 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charl-AI/stochastic-caching/e91381d5089ca2fb704c5cb634ab62f61dc9fa33/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_shared_cache.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from stocaching import SharedCache 8 | 9 | BYTES_PER_GIB = 1024**3 10 | 11 | CACHE_SIZE_GIB = 1 12 | DATASET_SIZE_GIB = 2 13 | DATA_DIMS = (3, 32, 32) 14 | DTYPE = torch.uint8 15 | 16 | DATASET_LEN = ( 17 | DATASET_SIZE_GIB 18 | * BYTES_PER_GIB 19 | // (DATA_DIMS[0] * DATA_DIMS[1] * DATA_DIMS[2] * DTYPE.itemsize) 20 | ) 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def cache() -> SharedCache: 25 | return SharedCache( 26 | CACHE_SIZE_GIB, 27 | DATASET_LEN, 28 | DATA_DIMS, 29 | dtype=DTYPE, 30 | ) 31 | 32 | 33 | def test_array_property(cache: SharedCache): 34 | cache_array = cache.array 35 | assert isinstance(cache_array, torch.Tensor) 36 | assert cache_array.shape == (len(cache),) + DATA_DIMS 37 | assert cache_array.dtype == DTYPE 38 | 39 | 40 | def test_aux_array_property(cache: SharedCache): 41 | aux_array = cache.aux_array 42 | assert isinstance(aux_array, torch.Tensor) 43 | assert aux_array.shape == (DATASET_LEN,) 44 | assert aux_array.dtype == DTYPE 45 | 46 | 47 | def test_cache_size(cache: SharedCache): 48 | cache_array = cache.array 49 | aux_array = cache.aux_array 50 | 51 | cache_bytes = np.prod(cache_array.shape) * cache_array.dtype.itemsize 52 | aux_bytes = np.prod(aux_array.shape) * aux_array.dtype.itemsize 53 | 54 | assert cache_bytes + aux_bytes <= CACHE_SIZE_GIB * BYTES_PER_GIB 55 | assert len(cache_array) == (CACHE_SIZE_GIB * BYTES_PER_GIB - aux_bytes) // ( 56 | np.prod(DATA_DIMS) * DTYPE.itemsize 57 | ) 58 | 59 | 60 | def test_set_slot(cache: SharedCache): 61 | # set a random slot within bounds of the cache 62 | cache_len = len(cache) 63 | idx_to_set = int(torch.randint(0, cache_len, (1,)).item()) 64 | value = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 65 | 66 | cache.set_slot(idx_to_set, value, allow_oob_idx=False, allow_overwrite=False) 67 | assert torch.all(cache[idx_to_set] == value) 68 | 69 | value2 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 70 | cache.set_slot(idx_to_set, value2, allow_overwrite=True) 71 | assert torch.all(cache[idx_to_set] == value2) 72 | 73 | value3 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 74 | with pytest.raises(RuntimeError): 75 | cache.set_slot(idx_to_set, value3, allow_overwrite=False) 76 | 77 | # set a random slot outside the bounds of the cache 78 | ds_len = len(cache.aux_array) 79 | if ds_len != cache_len: 80 | idx_to_set = int(torch.randint(cache_len, ds_len, (1,)).item()) 81 | value4 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 82 | 83 | with pytest.raises(IndexError): 84 | cache.set_slot(idx_to_set, value4, allow_oob_idx=False) 85 | 86 | cache.set_slot(idx_to_set, value4, allow_oob_idx=True) 87 | 88 | # try to set a slot outside the bounds of the dataset 89 | # note how it should raise an error even if 
allow_oob_idx=True 90 | idx = ds_len + 1 91 | value5 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 92 | with pytest.raises(IndexError): 93 | cache.set_slot(idx, value5, allow_oob_idx=True) 94 | 95 | 96 | def test_get_slot(cache: SharedCache): 97 | # get a random slot within bounds of the cache 98 | cache_len = len(cache) 99 | idx_to_get = int(torch.randint(0, cache_len, (1,)).item()) 100 | 101 | with pytest.raises(RuntimeError): 102 | _ = cache.get_slot(idx_to_get, allow_oob_idx=False, allow_empty_slot=False) 103 | 104 | val = cache.get_slot(idx_to_get, allow_oob_idx=False, allow_empty_slot=True) 105 | assert val is None 106 | 107 | set_val = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 108 | cache.set_slot(idx_to_get, set_val, allow_overwrite=False) 109 | val = cache.get_slot(idx_to_get, allow_oob_idx=False, allow_empty_slot=False) 110 | assert torch.all(val == set_val) 111 | 112 | # get a random slot outside the bounds of the cache, 113 | # but within the bounds of the dataset 114 | ds_len = len(cache.aux_array) 115 | if ds_len != cache_len: 116 | idx_to_get = int(torch.randint(cache_len, ds_len, (1,)).item()) 117 | with pytest.raises(IndexError): 118 | _ = cache.get_slot(idx_to_get, allow_oob_idx=False) 119 | val = cache.get_slot(idx_to_get, allow_oob_idx=True) 120 | assert val is None 121 | 122 | # try to get a slot outside the bounds of the dataset 123 | # note how it should raise an error even if allow_oob_idx=True 124 | idx_to_get = ds_len + 1 125 | with pytest.raises(IndexError): 126 | _ = cache.get_slot(idx_to_get) 127 | 128 | 129 | def test_clear_slots(cache: SharedCache): 130 | cache_len = len(cache) 131 | idx_to_set = int(torch.randint(0, cache_len, (1,)).item()) 132 | 133 | value = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 134 | cache.set_slot(idx_to_set, value, allow_oob_idx=False, allow_overwrite=False) 135 | assert torch.all(cache[idx_to_set] == value) 136 | 137 | value2 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 138 | with pytest.raises(RuntimeError): 139 | cache.set_slot(idx_to_set, value2, allow_overwrite=False) 140 | 141 | cache.clear() 142 | 143 | with pytest.raises(RuntimeError): 144 | _ = cache.get_slot(idx_to_set, allow_oob_idx=False, allow_empty_slot=False) 145 | 146 | val = cache.get_slot(idx_to_set, allow_oob_idx=False, allow_empty_slot=True) 147 | assert val is None 148 | 149 | cache.set_slot(idx_to_set, value2, allow_overwrite=False) 150 | assert torch.all(cache[idx_to_set] == value2) 151 | 152 | 153 | @pytest.mark.skip(reason="Not a test, just a helper") 154 | def set_slot(cache: SharedCache, idx: int, value: torch.Tensor): 155 | cache.set_slot(idx, value, allow_overwrite=True, allow_oob_idx=False) 156 | 157 | 158 | @pytest.mark.skip(reason="Not a test, just a helper") 159 | def get_slot(cache: SharedCache, idx: int): 160 | return cache.get_slot(idx, allow_oob_idx=False, allow_empty_slot=False) 161 | 162 | 163 | def test_slots_multiprocess(cache: SharedCache): 164 | cache_len = len(cache) 165 | write_idx_1 = int(torch.randint(0, cache_len, (1,)).item()) 166 | write_idx_2 = int(torch.randint(0, cache_len, (1,)).item()) 167 | write_value_1 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 168 | write_value_2 = torch.randint(0, 255, DATA_DIMS, dtype=DTYPE) 169 | 170 | with mp.Pool(2) as p: 171 | _ = p.starmap( 172 | set_slot, 173 | [ 174 | (cache, write_idx_1, write_value_1), 175 | (cache, write_idx_2, write_value_2), 176 | ], 177 | ) 178 | 179 | with mp.Pool(2) as p: 180 | read_value_1, read_value_2 = p.starmap( 181 | get_slot, [(cache, write_idx_1), 
(cache, write_idx_2)] 182 | ) 183 | 184 | assert torch.all(read_value_1 == write_value_1) 185 | assert torch.all(read_value_2 == write_value_2) 186 | --------------------------------------------------------------------------------