├── .github
│   └── workflows
│       ├── build.yml
│       ├── release.yml
│       └── release_test.yml
├── .gitignore
├── 3rd
│   └── CMakeLists.txt
├── CMakeLists.txt
├── MANIFEST.in
├── README.md
├── benchmark
│   ├── benchmark_adam.py
│   └── benchmark_cpuadam.py
├── csrc
│   ├── CMakeLists.txt
│   ├── aio.cpp
│   ├── async_file_io.cpp
│   ├── backend.cpp
│   ├── offload.cpp
│   ├── pthread_backend.cpp
│   ├── py_api.cpp
│   ├── space_mgr.cpp
│   └── uring.cpp
├── docker
│   └── Dockerfile
├── include
│   ├── aio.h
│   ├── async_file_io.h
│   ├── asyncio.h
│   ├── backend.h
│   ├── offload.h
│   ├── pthread_backend.h
│   ├── space_mgr.h
│   ├── threadpool.hpp
│   └── uring.h
├── requirements.txt
├── setup.py
├── tensornvme
│   ├── _C
│   │   └── __init__.pyi
│   ├── __init__.py
│   ├── async_file_io.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── check.py
│   │   └── cli.py
│   └── offload.py
├── tests
│   ├── CMakeLists.txt
│   ├── catch.hpp
│   ├── requirements.txt
│   ├── test_adam.py
│   ├── test_asyncio.cpp
│   ├── test_disk_offloader.py
│   ├── test_offload.cpp
│   └── test_space_mgr.cpp
└── version.txt
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 | 
3 | on:
4 |   pull_request:
5 |     types: [synchronize, labeled]
6 | 
7 | jobs:
8 |   build:
9 |     name: Build and Test TensorNVMe
10 |     if: |
11 |       github.event.pull_request.draft == false &&
12 |       github.base_ref == 'main' &&
13 |       github.event.pull_request.base.repo.full_name == 'hpcaitech/TensorNVMe' &&
14 |       contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
15 |     runs-on: ubuntu-latest
16 |     timeout-minutes: 30
17 |     steps:
18 |       - uses: actions/checkout@v2
19 |         with:
20 |           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
21 |       - uses: actions/setup-python@v2
22 |         with:
23 |           python-version: '3.7.12'
24 |       - name: Install tensornvme
25 |         run: |
26 |           pip install -r requirements.txt
27 |           pip install -v -e .
28 |       - name: Unit Testing
29 |         run: |
30 |           pip install -r tests/requirements.txt
31 |           PYTHONPATH=$PWD pytest tests
32 |           cd build
33 |           cmake ..
34 |           make
35 |           ./tests/test_asyncio
36 |           ./tests/test_space_mgr
37 |         env:
38 |           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 | 
3 | on: workflow_dispatch
4 | 
5 | jobs:
6 |   build-n-publish:
7 |     if: github.repository == 'hpcaitech/TensorNVMe' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
8 |     name: Build and publish Python 🐍 distributions 📦 to PyPI
9 |     runs-on: ubuntu-latest
10 |     timeout-minutes: 20
11 |     steps:
12 |       - uses: actions/checkout@v2
13 |       - uses: actions/setup-python@v2
14 |         with:
15 |           python-version: '3.7.12'
16 |       - run: pip install packaging
17 |       - run: python setup.py sdist
18 |       # this workflow is triggered manually (workflow_dispatch)
19 |       # and always publishes to PyPI
20 |       - name: Publish package to PyPI
21 |         uses: pypa/gh-action-pypi-publish@release/v1
22 |         with:
23 |           user: __token__
24 |           password: ${{ secrets.PYPI_API_TOKEN }}
25 |           verbose: true
--------------------------------------------------------------------------------
/.github/workflows/release_test.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Test PyPI
2 | 
3 | on: workflow_dispatch
4 | 
5 | jobs:
6 |   build-n-publish:
7 |     if: github.repository == 'hpcaitech/TensorNVMe' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
8 |     name: Build and publish Python 🐍 distributions 📦 to Test PyPI
9 |     runs-on: ubuntu-latest
10 |     timeout-minutes: 20
11 |     steps:
12 |       - uses: actions/checkout@v2
13 |       - uses: actions/setup-python@v2
14 |         with:
15 |           python-version: '3.7.12'
16 |       - run: pip install packaging
17 |       - run: python setup.py sdist
18 |       # this workflow is triggered manually (workflow_dispatch)
19 |       # and always publishes to Test PyPI
20 |       - name: Publish package to Test PyPI
21 |         uses: pypa/gh-action-pypi-publish@release/v1
22 |         with:
23 |           user: __token__
24 |           password: ${{ secrets.TEST_PYPI_API_TOKEN }}
25 |           repository_url: https://test.pypi.org/legacy/
26 |           verbose: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | cmake-build-debug-remote-host 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | .idea/ 162 | 163 | .vscode/ 164 | cmake-build/ -------------------------------------------------------------------------------- /3rd/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | if(DEFINED BACKEND_INSTALL_PREFIX) 4 | set(INSTALL_PREFIX ${BACKEND_INSTALL_PREFIX}) 5 | else() 6 | set(INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}) 7 | endif() 8 | 9 | if(DEFINED ENV{LD_LIBRARY_PATH}) 10 | string(REPLACE ":" ";" CMAKE_LIBRARY_PATH $ENV{LD_LIBRARY_PATH}) 11 | endif() 12 | 13 | find_library(LIBURING uring ${INSTALL_PREFIX}/lib) 14 | 15 | if(LIBURING) 16 | message("liburing is found in ${LIBURING}") 17 | else(LIBURING) 18 | message("liburing is not found, install in ${INSTALL_PREFIX}") 19 | ExternalProject_Add(extern_uring 20 | GIT_REPOSITORY https://github.com/axboe/liburing.git 21 | GIT_TAG liburing-2.2 22 | BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/liburing 23 | SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/liburing 24 | INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/liburing 25 | CONFIGURE_COMMAND ./configure --prefix=${INSTALL_PREFIX} 26 | BUILD_COMMAND make && make install 27 | ) 28 | 29 | add_library(uring STATIC IMPORTED GLOBAL) 30 | SET_PROPERTY(TARGET uring PROPERTY IMPORTED_LOCATION ${INSTALL_PREFIX}/lib/liburing.a) 31 | file(MAKE_DIRECTORY ${INSTALL_PREFIX}/include/) 32 | target_include_directories(uring INTERFACE ${INSTALL_PREFIX}/include/) 33 | add_dependencies(uring extern_uring) 34 | endif(LIBURING) 35 | 36 | find_library(LIBAIO aio ${INSTALL_PREFIX}/lib) 37 | 38 | if(LIBAIO) 39 | message("libaio is found in ${LIBAIO}") 40 | else(LIBAIO) 41 | message("libaio is not found, install in ${INSTALL_PREFIX}") 42 | ExternalProject_Add(extern_aio 43 | GIT_REPOSITORY https://pagure.io/libaio.git 44 | GIT_TAG libaio-0.3.113 45 | BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libaio 46 | SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/libaio 47 | INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/libaio 48 | CONFIGURE_COMMAND "" 49 | UPDATE_COMMAND "" 50 | BUILD_COMMAND make prefix=${INSTALL_PREFIX} install 51 | INSTALL_COMMAND "" 52 | ) 53 | 54 | add_library(aio STATIC IMPORTED GLOBAL) 55 | SET_PROPERTY(TARGET aio PROPERTY IMPORTED_LOCATION ${INSTALL_PREFIX}/lib/libaio.a) 56 | file(MAKE_DIRECTORY ${INSTALL_PREFIX}/include/) 57 | target_include_directories(aio INTERFACE ${INSTALL_PREFIX}/include/) 58 | add_dependencies(aio extern_aio) 59 | endif(LIBAIO) -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | project( 3 | tensornvme 4 | DESCRIPTION "A tensor disk offloader without data copying." 
5 |     LANGUAGES CXX
6 | )
7 | 
8 | set(CMAKE_CXX_STANDARD 14)
9 | set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
10 | 
11 | add_subdirectory(3rd)
12 | add_subdirectory(csrc)
13 | 
14 | if(EXISTS ${CMAKE_SOURCE_DIR}/tests)
15 |     add_subdirectory(tests)
16 | endif()
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.txt README.md
2 | recursive-include csrc *.cpp *.txt
3 | recursive-include include *.h
4 | recursive-include 3rd *.txt
5 | recursive-include tensornvme *.pyi
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorNVMe
2 | 
3 | A Python library that provides APIs for moving PyTorch tensors between CPU and NVMe storage.
4 | 
5 | ## Dependencies
6 | 
7 | - [liburing](https://github.com/axboe/liburing)
8 | - [libaio](https://pagure.io/libaio)
9 | 
10 | ## Install
11 | 
12 | This package is only supported on Linux. `liburing` and `libaio` can be installed automatically. `liburing` requires Linux >= `5.10` and won't be installed if your kernel version is lower.
13 | 
14 | The installer searches for `libaio` and `liburing` in `/usr/lib`, `/usr/lib64` and `$LD_LIBRARY_PATH`. If they are not found, the backends are installed in `~/.tensornvme`, and `~/.bashrc` is modified to set `$LD_LIBRARY_PATH` correctly. **Please `source ~/.bashrc` after installation.** If you use another shell, please make sure `$LD_LIBRARY_PATH` is set correctly.
15 | 
16 | > You must install PyTorch and CMake before installing tensornvme. Whenever you upgrade PyTorch, remember to reinstall tensornvme.
17 | 
18 | ### From source
19 | 
20 | ```shell
21 | git clone https://github.com/hpcaitech/TensorNVMe.git && cd TensorNVMe
22 | ```
23 | 
24 | First, install requirements:
25 | 
26 | ```shell
27 | pip install -r requirements.txt
28 | ```
29 | 
30 | To install `tensornvme` with both `liburing` and `libaio`:
31 | 
32 | ```shell
33 | pip install -v --no-cache-dir .
34 | ```
35 | 
36 | To install `tensornvme` with only `liburing`:
37 | 
38 | ```shell
39 | DISABLE_AIO=1 pip install -v --no-cache-dir .
40 | ```
41 | 
42 | To install `tensornvme` with only `libaio`:
43 | 
44 | ```shell
45 | DISABLE_URING=1 pip install -v --no-cache-dir .
46 | ```
47 | 
48 | If you want to install `libaio` or `liburing` system-wide:
49 | 
50 | ```shell
51 | WITH_ROOT=1 sudo pip install -v --no-cache-dir .
52 | ```
53 | 
54 | They will then be installed in `/usr`, and `~/.bashrc` will not be modified. Make sure you have root access.
55 | 
56 | ### From PIP
57 | 
58 | ```shell
59 | pip install packaging
60 | pip install tensornvme
61 | ```
62 | 
63 | All acceptable environment variables are the same as when installing from source.
64 | 
65 | ## Use docker
66 | 
67 | ```shell
68 | git clone https://github.com/hpcaitech/TensorNVMe.git && cd TensorNVMe/docker && docker build -t tensornvme .
69 | ```
70 | 
71 | ## CLI
72 | 
73 | We provide a CLI to check whether the backends work correctly.
74 | 
75 | ```shell
76 | tensornvme check
77 | ```
78 | 
79 | ## Usage
80 | 
81 | TensorNVMe provides both synchronous and asynchronous I/O APIs.
82 | 
83 | > Only CPU-resident, contiguous tensors can be offloaded.
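A tensor that lives on GPU or is a non-contiguous view must therefore be converted before offloading. A minimal sketch (only standard PyTorch calls plus the `DiskOffloader` API shown below; the shapes are arbitrary):

```python
import torch
from tensornvme import DiskOffloader

offloader = DiskOffloader('./offload')
t = torch.rand(4, 4).t()  # a transposed view: on CPU, but not contiguous
t = t.contiguous()        # compact the storage so it can be offloaded
# a GPU tensor would additionally need t = t.cpu() first
offloader.sync_write(t)
offloader.sync_read(t)
```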
84 | 
85 | Synchronous API:
86 | 
87 | ```python
88 | import torch
89 | from tensornvme import DiskOffloader
90 | 
91 | x = torch.rand(2, 2)
92 | y = torch.rand(4, 4, 4)
93 | offloader = DiskOffloader('./offload')
94 | offloader.sync_write(x)
95 | # x is saved to a file on disk (in the ./offload folder) and the memory of x is freed
96 | offloader.sync_read(x)
97 | # x is restored
98 | offloader.sync_writev([x, y])
99 | # x and y are offloaded
100 | offloader.sync_readv([x, y])
101 | # x and y are restored.
102 | # sync_writev() and sync_readv() are order sensitive
103 | # E.g. sync_writev([x, y]) and sync_writev([y, x]) are different
104 | ```
105 | 
106 | Asynchronous API:
107 | 
108 | ```python
109 | import torch
110 | from tensornvme import DiskOffloader
111 | 
112 | x = torch.rand(2, 2)
113 | y = torch.rand(4, 4, 4)
114 | offloader = DiskOffloader('./offload')
115 | offloader.async_write(x)
116 | # x is being offloaded in the background
117 | offloader.sync_write_events()
118 | # x is offloaded and the memory of x is freed
119 | offloader.async_read(x)
120 | # x is being restored in the background
121 | offloader.sync_read_events()
122 | # x is restored
123 | offloader.async_writev([x, y])
124 | # x and y are being offloaded in the background
125 | offloader.synchronize()
126 | # synchronize() waits for both write and read events.
127 | offloader.async_readv([x, y])
128 | offloader.synchronize()
129 | # x and y are restored.
130 | # async_writev() and async_readv() are also order sensitive
131 | ```
132 | 
133 | You can use the asynchronous API to overlap computation and data movement.
134 | 
135 | ```python
136 | tensors = []
137 | 
138 | for _ in range(10):
139 |     tensor = torch.rand(2, 2)
140 |     tensors.append(tensor)
141 |     offloader.sync_write(tensor)
142 | 
143 | offloader.sync_read(tensors[0])
144 | 
145 | # prefetch=1: write tensor[i] while reading tensor[i+1]
146 | for i, tensor in enumerate(tensors):
147 |     offloader.sync_read_events()
148 |     if i + 1 < len(tensors):
149 |         offloader.async_read(tensors[i+1])
150 |     tensor.mul_(2.0)  # compute
151 |     offloader.sync_write_events()
152 |     offloader.async_write(tensor)
153 | offloader.synchronize()
154 | ```
155 | 
156 | ## How to test
157 | 
158 | We have C++ test scripts for the `AsyncIO` and `SpaceManager` classes. Make sure you have installed `liburing` and `libaio` and set the environment variables correctly before testing. To run the tests:
159 | 
160 | ```shell
161 | mkdir build
162 | cd build
163 | cmake ..
164 | make
165 | ./test_asyncio
166 | ./test_space_mgr
167 | ```
168 | 
169 | We also have Python unit tests. Make sure you have installed `pytest`. To run them:
170 | 
171 | ```shell
172 | pytest ./tests
173 | ```
174 | 
175 | ## How to benchmark
176 | 
177 | We have benchmarks for `Adam` and `CPUAdam` with different backends and prefetch depths to validate TensorNVMe's speed.
To run the benchmark: 178 | 179 | ```shell 180 | cd benchmark 181 | python benchmark_adam.py 182 | python benchmark_cpuadam.py 183 | ``` 184 | -------------------------------------------------------------------------------- /benchmark/benchmark_adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from tqdm import tqdm 7 | from transformers import GPT2Config, GPT2LMHeadModel 8 | 9 | from tensornvme import DiskOffloader 10 | 11 | N_WARMUP = 2 12 | N_ACTIVATE = 4 13 | 14 | 15 | class GPTLMModel(nn.Module): 16 | def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): 17 | super().__init__() 18 | self.checkpoint = checkpoint 19 | self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, 20 | n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) 21 | if checkpoint: 22 | self.model.gradient_checkpointing_enable() 23 | 24 | def forward(self, input_ids, attention_mask): 25 | # Only return lm_logits 26 | return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] 27 | 28 | 29 | def gpt2_medium(checkpoint=False): 30 | return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) 31 | 32 | 33 | def gpt2_xl(checkpoint=False): 34 | return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=16, checkpoint=checkpoint) 35 | 36 | 37 | def gpt2_8b(checkpoint=False): 38 | return GPTLMModel(hidden_size=4096, num_layers=90, num_attention_heads=16, checkpoint=checkpoint) 39 | 40 | 41 | def gpt2_20b(checkpoint=False): 42 | return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) 43 | 44 | 45 | def adam(step, lr, param, grad, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, eps=1e-12): 46 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 47 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 48 | 49 | bias_correction1 = 1 - beta1 ** step 50 | bias_correction2 = 1 - beta2 ** step 51 | step_size = lr / bias_correction1 52 | bias_correction2_sqrt = math.sqrt(bias_correction2) 53 | denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 54 | param.addcdiv_(exp_avg, denom, value=-step_size) 55 | 56 | 57 | class Adam(torch.optim.Optimizer): 58 | def __init__(self, params, lr, betas=(0.9, 0.999), offloader: Optional[DiskOffloader] = None, prefetch: int = 0, vecio: bool = False) -> None: 59 | default = dict(lr=lr, betas=betas) 60 | super().__init__(params, default) 61 | self.offloader = offloader 62 | self.prefetch = prefetch 63 | self.vecio = vecio 64 | self.param_to_group = {} 65 | # init states 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.requires_grad: 69 | self.param_to_group[p] = group 70 | state = self.state[p] 71 | state['step'] = 0 72 | state['exp_avg'] = torch.zeros_like(p) 73 | state['exp_avg_sq'] = torch.zeros_like(p) 74 | if self.offloader is None: 75 | continue 76 | if vecio: 77 | self.offloader.sync_writev( 78 | [state['exp_avg'], state['exp_avg_sq']]) 79 | else: 80 | self.offloader.sync_write(state['exp_avg']) 81 | self.offloader.sync_write(state['exp_avg_sq']) 82 | 83 | def step(self, closure=None): 84 | loss = None 85 | if closure is not None: 86 | with torch.enable_grad(): 87 | loss = closure() 88 | 89 | params = [ 90 | p for group in self.param_groups 
for p in group['params'] if p.grad is not None] 91 | if self.offloader is not None and self.prefetch > 0: 92 | for p in params[:self.prefetch]: 93 | state = self.state[p] 94 | if self.vecio: 95 | self.offloader.sync_readv( 96 | [state['exp_avg'], state['exp_avg_sq']]) 97 | else: 98 | self.offloader.sync_read(state['exp_avg']) 99 | self.offloader.sync_read(state['exp_avg_sq']) 100 | 101 | for i, p in enumerate(params): 102 | state = self.state[p] 103 | group = self.param_to_group[p] 104 | state['step'] += 1 105 | beta1, beta2 = group['betas'] 106 | self._pre_step(i, params) 107 | adam(state['step'], group['lr'], p, p.grad, state['exp_avg'], 108 | state['exp_avg_sq'], beta1=beta1, beta2=beta2) 109 | self._post_step(i, params) 110 | 111 | return loss 112 | 113 | def _pre_step(self, idx, params): 114 | if self.offloader is None: 115 | return 116 | if self.prefetch > 0: 117 | if idx % self.prefetch == 0: 118 | self.offloader.sync_read_events() 119 | if idx + self.prefetch < len(params): 120 | for prefetch_p in params[idx + self.prefetch:idx + self.prefetch * 2]: 121 | prefetch_state = self.state[prefetch_p] 122 | if self.vecio: 123 | self.offloader.async_readv( 124 | [prefetch_state['exp_avg'], prefetch_state['exp_avg_sq']]) 125 | else: 126 | self.offloader.async_read( 127 | prefetch_state['exp_avg']) 128 | self.offloader.async_read( 129 | prefetch_state['exp_avg_sq']) 130 | else: 131 | state = self.state[params[idx]] 132 | if self.vecio: 133 | self.offloader.sync_readv( 134 | [state['exp_avg'], state['exp_avg_sq']]) 135 | else: 136 | self.offloader.sync_read(state['exp_avg']) 137 | self.offloader.sync_read(state['exp_avg_sq']) 138 | 139 | def _post_step(self, idx, params): 140 | if self.offloader is None: 141 | return 142 | state = self.state[params[idx]] 143 | if self.prefetch > 0: 144 | if idx % self.prefetch == 0: 145 | self.offloader.sync_write_events() 146 | if self.vecio: 147 | self.offloader.async_writev( 148 | [state['exp_avg'], state['exp_avg_sq']]) 149 | else: 150 | self.offloader.async_write(state['exp_avg']) 151 | self.offloader.async_write(state['exp_avg_sq']) 152 | else: 153 | if self.vecio: 154 | self.offloader.sync_writev( 155 | [state['exp_avg'], state['exp_avg_sq']]) 156 | else: 157 | self.offloader.sync_write(state['exp_avg']) 158 | self.offloader.sync_write(state['exp_avg_sq']) 159 | 160 | 161 | def run_adam(model: torch.nn.Module, nvme_offload: bool, backend: str, prefetch: int, vecio: bool): 162 | offloader = None 163 | if nvme_offload: 164 | offloader = DiskOffloader('.', 8, backend=backend) 165 | optimizer = Adam(model.parameters(), 1e-3, 166 | offloader=offloader, prefetch=prefetch, vecio=vecio) 167 | for p in model.parameters(): 168 | p.grad = torch.rand_like(p) 169 | for _ in range(N_WARMUP): 170 | optimizer.step() 171 | if not nvme_offload: 172 | desc = 'CPU' 173 | postfix = None 174 | else: 175 | desc = 'NVME' 176 | postfix = {'backend': backend, 'prefetch': prefetch, 'vecio': vecio} 177 | for _ in tqdm(range(N_ACTIVATE), desc=desc, postfix=postfix): 178 | optimizer.step() 179 | 180 | 181 | if __name__ == '__main__': 182 | model = gpt2_xl() 183 | with torch.no_grad(): 184 | run_adam(model, False, 'uring', 0, False) 185 | run_adam(model, True, 'uring', 0, False) 186 | run_adam(model, True, 'uring', 0, True) 187 | run_adam(model, True, 'uring', 1, False) 188 | run_adam(model, True, 'uring', 1, True) 189 | run_adam(model, True, 'uring', 2, False) 190 | run_adam(model, True, 'uring', 2, True) 191 | run_adam(model, True, 'uring', 4, False) 192 | run_adam(model, True, 
'uring', 4, True) 193 | -------------------------------------------------------------------------------- /benchmark/benchmark_cpuadam.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from colossalai.nn.optimizer.cpu_adam import CPUAdam 5 | from tqdm import tqdm 6 | 7 | from tensornvme import DiskOffloader 8 | from benchmark_adam import gpt2_xl 9 | 10 | N_WARMUP = 2 11 | N_ACTIVATE = 4 12 | 13 | 14 | class NVMECPUAdam(CPUAdam): 15 | def __init__(self, model_params, 16 | lr=1e-3, 17 | bias_correction=True, 18 | betas=(0.9, 0.999), 19 | offloader: Optional[DiskOffloader] = None, 20 | prefetch: int = 0, 21 | vecio: bool = False, 22 | eps=1e-8, 23 | weight_decay=0, 24 | adamw_mode=True, 25 | simd_log=False): 26 | super(NVMECPUAdam, self).__init__( 27 | model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, simd_log) 28 | 29 | self.offloader = offloader 30 | self.prefetch = prefetch 31 | self.vecio = vecio 32 | # init states 33 | for group in self.param_groups: 34 | for p in group['params']: 35 | if p.requires_grad: 36 | state = self.state[p] 37 | state['step'] = 0 38 | state['exp_avg'] = torch.zeros_like(p) 39 | state['exp_avg_sq'] = torch.zeros_like(p) 40 | if self.offloader is None: 41 | continue 42 | if vecio: 43 | self.offloader.sync_writev( 44 | [state['exp_avg'], state['exp_avg_sq']]) 45 | else: 46 | self.offloader.sync_write(state['exp_avg']) 47 | self.offloader.sync_write(state['exp_avg_sq']) 48 | 49 | @torch.no_grad() 50 | def step(self, closure=None): 51 | loss = None 52 | if closure is not None: 53 | with torch.enable_grad(): 54 | loss = closure() 55 | 56 | for _, group in enumerate(self.param_groups): 57 | self._init_step(group['params']) 58 | for p_i, p in enumerate(group['params']): 59 | 60 | state = self.state[p] 61 | target_device = p.device 62 | state['step'] += 1 63 | beta1, beta2 = group['betas'] 64 | 65 | if target_device.type == 'cpu': 66 | assert p.data.numel() == p.grad.data.numel( 67 | ), "parameter and gradient should have the same size" 68 | assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" 69 | assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" 70 | self._pre_step(p_i, group['params']) 71 | self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], 72 | group['weight_decay'], group['bias_correction'], p.data, p.grad.data, 73 | state['exp_avg'], state['exp_avg_sq'], -1) 74 | self._post_step(p_i, group['params']) 75 | elif target_device.type == 'cuda': 76 | assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" 77 | assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" 78 | 79 | bias_correction1 = 1 - beta1 ** state['step'] 80 | bias_correction2 = 1 - beta2 ** state['step'] 81 | 82 | # adam on cuda 83 | self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], 84 | beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, 85 | bias_correction2, self.adamw_mode) 86 | else: 87 | raise RuntimeError 88 | return loss 89 | 90 | def _init_step(self, params): 91 | if self.offloader is not None and self.prefetch > 0: 92 | for p in params[:self.prefetch]: 93 | state = self.state[p] 94 | if self.vecio: 95 | self.offloader.sync_readv( 96 | [state['exp_avg'], state['exp_avg_sq']]) 97 | else: 98 | self.offloader.sync_read(state['exp_avg']) 99 | self.offloader.sync_read(state['exp_avg_sq']) 
100 | 101 | def _pre_step(self, idx, params): 102 | if self.offloader is None: 103 | return 104 | if self.prefetch > 0: 105 | if idx % self.prefetch == 0: 106 | self.offloader.sync_read_events() 107 | if idx + self.prefetch < len(params): 108 | for prefetch_p in params[idx + self.prefetch:idx + self.prefetch * 2]: 109 | prefetch_state = self.state[prefetch_p] 110 | if self.vecio: 111 | self.offloader.async_readv( 112 | [prefetch_state['exp_avg'], prefetch_state['exp_avg_sq']]) 113 | else: 114 | self.offloader.async_read( 115 | prefetch_state['exp_avg']) 116 | self.offloader.async_read( 117 | prefetch_state['exp_avg_sq']) 118 | else: 119 | state = self.state[params[idx]] 120 | if self.vecio: 121 | self.offloader.sync_readv( 122 | [state['exp_avg'], state['exp_avg_sq']]) 123 | else: 124 | self.offloader.sync_read(state['exp_avg']) 125 | self.offloader.sync_read(state['exp_avg_sq']) 126 | 127 | def _post_step(self, idx, params): 128 | if self.offloader is None: 129 | return 130 | state = self.state[params[idx]] 131 | if self.prefetch > 0: 132 | if idx % self.prefetch == 0: 133 | self.offloader.sync_write_events() 134 | if self.vecio: 135 | self.offloader.async_writev( 136 | [state['exp_avg'], state['exp_avg_sq']]) 137 | else: 138 | self.offloader.async_write(state['exp_avg']) 139 | self.offloader.async_write(state['exp_avg_sq']) 140 | else: 141 | if self.vecio: 142 | self.offloader.sync_writev( 143 | [state['exp_avg'], state['exp_avg_sq']]) 144 | else: 145 | self.offloader.sync_write(state['exp_avg']) 146 | self.offloader.sync_write(state['exp_avg_sq']) 147 | 148 | 149 | def run_adam(model: torch.nn.Module, nvme_offload: bool, backend: str, prefetch: int, vecio: bool): 150 | offloader = None 151 | if nvme_offload: 152 | offloader = DiskOffloader('.', 8, backend=backend) 153 | params = list(model.cpu().parameters()) 154 | for _, p in enumerate(params): 155 | if p.grad is None and p.requires_grad: 156 | p.grad = torch.rand_like(p.data, dtype=torch.float) 157 | optimizer = NVMECPUAdam( 158 | params, 1e-3, offloader=offloader, prefetch=prefetch, vecio=vecio) 159 | for p in model.parameters(): 160 | p.grad = torch.rand_like(p) 161 | for _ in range(N_WARMUP): 162 | optimizer.step() 163 | if not nvme_offload: 164 | desc = 'CPU' 165 | postfix = None 166 | else: 167 | desc = 'NVME' 168 | postfix = {'backend': backend, 'prefetch': prefetch, 'vecio': vecio} 169 | for _ in tqdm(range(N_ACTIVATE), desc=desc, postfix=postfix): 170 | optimizer.step() 171 | 172 | 173 | if __name__ == '__main__': 174 | model = gpt2_xl() 175 | with torch.no_grad(): 176 | run_adam(model, False, 'uring', 0, False) 177 | run_adam(model, True, 'uring', 0, False) 178 | run_adam(model, True, 'uring', 0, True) 179 | run_adam(model, True, 'uring', 1, False) 180 | run_adam(model, True, 'uring', 1, True) 181 | run_adam(model, True, 'uring', 2, False) 182 | run_adam(model, True, 'uring', 2, True) 183 | run_adam(model, True, 'uring', 4, False) 184 | run_adam(model, True, 'uring', 4, True) 185 | -------------------------------------------------------------------------------- /csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(colo_asyncio 2 | OBJECT 3 | aio.cpp 4 | uring.cpp) 5 | target_link_libraries(colo_asyncio 6 | PUBLIC uring 7 | PUBLIC aio) 8 | target_include_directories(colo_asyncio PUBLIC ../include) 9 | 10 | 11 | add_library(space_mgr 12 | OBJECT 13 | space_mgr.cpp) 14 | target_include_directories(space_mgr PUBLIC ../include) 15 | 
--------------------------------------------------------------------------------
/csrc/aio.cpp:
--------------------------------------------------------------------------------
1 | #include "aio.h"
2 | 
3 | AIOAsyncIO::AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks)
4 | {
5 |     // printf("Initializing the io Context\n");
6 |     this->max_nr = n_entries;
7 |     io_setup(n_entries, &(this->io_ctx)); /* initialize the io context */
8 |     this->timeout.tv_sec = 0;
9 |     this->timeout.tv_nsec = 100000000;
10 | }
11 | 
12 | void AIOAsyncIO::register_file(int fd) {}
13 | 
14 | AIOAsyncIO::~AIOAsyncIO()
15 | {
16 |     // printf("Closing AsyncIO context\n");
17 |     synchronize();
18 |     io_destroy(this->io_ctx);
19 | }
20 | 
21 | void AIOAsyncIO::get_event(WaitType wt)
22 | {
23 |     std::unique_ptr<io_event[]> events(new io_event[this->max_nr]);
24 |     int num_events;
25 | 
26 |     if (wt == WAIT)
27 |         num_events = io_getevents(io_ctx, this->min_nr, this->max_nr, events.get(), &(this->timeout)); /* get the number of completed async I/O events */
28 |     else
29 |         num_events = io_getevents(io_ctx, 0, this->max_nr, events.get(), &(this->timeout)); /* get the number of completed async I/O events */
30 | 
31 |     for (int i = 0; i < num_events; i++) /* fetch each completed event and handle it */
32 |     {
33 |         struct io_event event = events.get()[i];
34 |         std::unique_ptr<IOData> data(static_cast<IOData *>(event.data));
35 |         if (data->type == WRITE)
36 |             this->n_write_events--;
37 |         else if (data->type == READ)
38 |             this->n_read_events--;
39 |         else
40 |             throw std::runtime_error("Unknown IO event type");
41 |         if (data->callback != nullptr)
42 |             data->callback();
43 |         // printf("%d tasks to be done\n", this->n_write_events);
44 |     }
45 | }
46 | 
47 | void AIOAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
48 | {
49 |     struct iocb iocb
50 |     {
51 |     }; // build an async I/O request
52 |     struct iocb *iocbs = &iocb;
53 |     auto *data = new IOData(WRITE, callback);
54 | 
55 |     io_prep_pwrite(&iocb, fd, buffer, n_bytes, (long long)offset); // initialize the request; the last argument is the file offset
56 | 
57 |     iocb.data = data;
58 |     io_submit(this->io_ctx, 1, &iocbs); // submit the I/O; this does not block
59 | 
60 |     this->n_write_events++;
61 | }
62 | 
63 | void AIOAsyncIO::read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
64 | {
65 |     struct iocb iocb
66 |     {
67 |     }; // build an async I/O request
68 |     struct iocb *iocbs = &iocb;
69 |     auto *data = new IOData(READ, callback);
70 | 
71 |     io_prep_pread(&iocb, fd, buffer, n_bytes, (long long)offset);
72 | 
73 |     iocb.data = data;
74 |     io_submit(this->io_ctx, 1, &iocbs); /* submit the I/O; this does not block */
75 | 
76 |     this->n_read_events++;
77 | }
78 | 
79 | void AIOAsyncIO::sync_write_events()
80 | {
81 |     while (this->n_write_events > 0)
82 |         get_event(WAIT);
83 | }
84 | 
85 | void AIOAsyncIO::sync_read_events()
86 | {
87 |     while (this->n_read_events > 0)
88 |         get_event(WAIT);
89 | }
90 | 
91 | void AIOAsyncIO::synchronize()
92 | {
93 |     while (this->n_write_events > 0 || this->n_read_events > 0)
94 |         get_event(WAIT);
95 | }
96 | 
97 | void AIOAsyncIO::writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback)
98 | {
99 |     struct iocb iocb
100 |     {
101 |     }; // build an async I/O request
102 |     struct iocb *iocbs = &iocb;
103 |     auto *data = new IOData(WRITE, callback, iov);
104 | 
105 |     io_prep_pwritev(&iocb, fd, iov, iovcnt, (long long)offset); // initialize the request; the last argument is the file offset
106 | 
107 |     iocb.data = data;
108 |     io_submit(this->io_ctx, 1, &iocbs); // submit the I/O; this does not block
109 | 
110 |     this->n_write_events++;
111 | }
112 | 
113 | void AIOAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback)
114 | {
115 |     struct iocb iocb
116 |     {
117 |     }; // build an async I/O request
118 |     struct iocb *iocbs = &iocb;
119 |     auto *data = new IOData(READ, callback, iov);
120 | 
121 |     io_prep_preadv(&iocb, fd, iov, iovcnt, (long long)offset);
122 | 
123 |     iocb.data = data;
124 |     io_submit(this->io_ctx, 1, &iocbs); /* submit the I/O; this does not block */
125 | 
126 |     this->n_read_events++;
127 | }
128 | 
129 | void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
130 | {
131 |     if (t.is_cuda())
132 |     {
133 |         if (pinned.has_value())
134 |         {
135 |             pinned.value().copy_(t);
136 |             t = pinned.value();
137 |         }
138 |         else
139 |         {
140 |             t = t.to(torch::kCPU);
141 |         }
142 |     }
143 |     void *buffer = t.data_ptr();
144 |     size_t n_bytes = t.numel() * t.element_size();
145 |     this->write(fd, buffer, n_bytes, offset, callback);
146 | }
147 | 
148 | void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
149 | void AIOAsyncIO::sync_h2d() {}
150 | void AIOAsyncIO::register_tasks(unsigned int num_tasks) {}
--------------------------------------------------------------------------------
/csrc/async_file_io.cpp:
--------------------------------------------------------------------------------
1 | #include "async_file_io.h"
2 | 
3 | AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks) : fd(fd), aio(create_asyncio(n_entries, backend, n_tasks)) {}
4 | 
5 | void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
6 | {
7 |     void *ptr = reinterpret_cast<void *>(buffer);
8 |     this->aio->write(this->fd, ptr, n_bytes, offset, callback);
9 | }
10 | 
11 | void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
12 | {
13 |     this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
14 | }
15 | 
16 | void AsyncFileWriter::register_h2d(unsigned int num_tensors)
17 | {
18 |     this->aio->register_h2d(num_tensors);
19 | }
20 | 
21 | void AsyncFileWriter::sync_h2d()
22 | {
23 |     this->aio->sync_h2d();
24 | }
25 | 
26 | void AsyncFileWriter::synchronize()
27 | {
28 |     this->aio->synchronize();
29 | }
30 | 
31 | AsyncFileWriter::~AsyncFileWriter()
32 | {
33 |     delete this->aio;
34 | }
35 | 
36 | void AsyncFileWriter::register_tasks(unsigned int num_tasks)
37 | {
38 |     this->aio->register_tasks(num_tasks);
39 | }
--------------------------------------------------------------------------------
/csrc/backend.cpp:
--------------------------------------------------------------------------------
1 | #include "backend.h"
2 | #include <algorithm>
3 | #include <cassert>
4 | #include <cctype>
5 | #include <cerrno>
6 | #include <cstdio>
7 | #include <cstdlib>
8 | #include <cstring>
9 | #include <iostream>
10 | #include <memory>
11 | 
12 | #ifndef DISABLE_URING
13 | #include "uring.h"
14 | #endif
15 | #ifndef DISABLE_AIO
16 | #include "aio.h"
17 | #endif
18 | #ifndef DISABLE_PTHREAD
19 | #include "pthread_backend.h"
20 | #endif
21 | 
22 | std::unordered_set<std::string> get_backends()
23 | {
24 |     std::unordered_set<std::string> backends;
25 | #ifndef DISABLE_URING
26 |     backends.insert("uring");
27 | #endif
28 | #ifndef DISABLE_AIO
29 |     backends.insert("aio");
30 | #endif
31 | #ifndef DISABLE_PTHREAD
32 |     backends.insert("pthread");
33 | #endif
34 |     return backends;
35 | }
36 | 
37 | void probe_asyncio(const std::string &backend)
38 | {
39 |     FILE *fp = tmpfile();
40 |     if (!fp)
41 |     {
42 |         printf("Create tmpfile error: %s\n", strerror(errno));
43 |         throw std::runtime_error("asyncio probe failed: cannot create tmpfile\n");
44 |     }
45 |     try
46 |     {
47 |         std::unique_ptr<AsyncIO> aio;
48 |         if (backend == "uring")
49 |         {
50 | #ifndef DISABLE_URING
51 |             aio.reset(new UringAsyncIO(2, 0));
52 | #else
53 |             throw std::runtime_error("backend uring is not installed\n");
54 | #endif
55 |         }
56 |         else if (backend == "aio")
57 |         {
58 | #ifndef DISABLE_AIO
59 |             aio.reset(new AIOAsyncIO(2, 0));
60 | #else
61 |             throw std::runtime_error("backend aio is not installed\n");
62 | #endif
63 |         }
64 |         else if (backend == "pthread")
65 |         {
66 | #ifndef DISABLE_PTHREAD
67 |             aio.reset(new PthreadAsyncIO(2, 0));
68 | #else
69 |             throw std::runtime_error("backend pthread is not installed\n");
70 | #endif
71 |         }
72 |         else
73 |         {
74 |             throw std::runtime_error("unknown backend");
75 |         }
76 | 
77 |         int fd = fileno(fp);
78 |         const int n_loop = 5, n_len = 18;
79 | 
80 |         char text[n_loop][n_len];
81 |         memset(text, 'x', sizeof(text)); // fill the probe buffer so the round-trip comparison is deterministic
82 |         int offset = 0;
83 |         size_t len;
84 |         for (int i = 0; i < n_loop; i++)
85 |         {
86 |             len = n_len;
87 |             aio->write(fd, text[i], len, offset, nullptr);
88 |             offset += len;
89 |         }
90 |         aio->sync_write_events();
91 | 
92 |         char new_text[n_loop][n_len];
93 |         offset = 0;
94 |         for (int i = 0; i < n_loop; i++)
95 |         {
96 |             len = n_len;
97 |             aio->read(fd, new_text[i], len, offset, nullptr);
98 |             offset += len;
99 |         }
100 |         aio->sync_read_events();
101 |         for (int i = 0; i < n_loop; i++)
102 |         {
103 |             for (int j = 0; j < n_len; j++)
104 |             {
105 |                 assert(text[i][j] == new_text[i][j]);
106 |             }
107 |         }
108 |         fclose(fp);
109 |     }
110 |     catch (...)
111 |     {
112 |         fclose(fp);
113 |         throw std::runtime_error(backend + " probe failed\n");
114 |     }
115 | }
116 | 
117 | bool probe_backend(const std::string &backend)
118 | {
119 |     std::unordered_set<std::string> backends = get_backends();
120 |     if (backends.find(backend) == backends.end())
121 |         return false;
122 |     try
123 |     {
124 |         probe_asyncio(backend);
125 |         return true;
126 |     }
127 |     catch (...)
128 |     {
129 |         return false;
130 |     }
131 | }
132 | 
133 | std::string get_default_backend()
134 | {
135 |     const char *env = getenv("TENSORNVME_BACKEND");
136 |     if (env == nullptr)
137 |     {
138 |         return std::string("");
139 |     }
140 |     return std::string(env);
141 | }
142 | 
143 | bool get_debug_flag()
144 | {
145 |     const char *env_ = getenv("TENSORNVME_DEBUG");
146 |     if (env_ == nullptr)
147 |     {
148 |         return false;
149 |     }
150 |     std::string env(env_);
151 |     std::transform(env.begin(), env.end(), env.begin(),
152 |                    [](unsigned char c)
153 |                    { return std::tolower(c); });
154 |     return env == "1" || env == "true";
155 | }
156 | 
157 | std::string get_debug_log()
158 | {
159 |     const char *env_ = getenv("TENSORNVME_DEBUG_LOG");
160 |     if (env_ == nullptr)
161 |     {
162 |         return std::string("");
163 |     }
164 |     return std::string(env_);
165 | }
166 | 
167 | AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks)
168 | {
169 |     std::unordered_set<std::string> backends = get_backends();
170 |     std::string default_backend = get_default_backend();
171 |     bool is_debugging = get_debug_flag();
172 | 
173 |     if (backends.empty())
174 |         throw std::runtime_error("No asyncio backend is installed");
175 | 
176 |     if (default_backend.size() > 0)
177 |     { // priority 1: environ is set
178 |         if (is_debugging)
179 |         {
180 |             std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl;
181 |         }
182 |         backend = default_backend;
183 |     }
184 |     else if (backend.size() > 0)
185 |     { // priority 2: backend is set
186 |         if (backends.find(backend) == backends.end())
187 |             throw std::runtime_error("Unsupported backend: " + backend);
188 |     }
189 | 
190 |     if (!probe_backend(backend))
191 |         throw std::runtime_error("Backend \"" + backend + "\" is not installed correctly");
install correctly"); 188 | 189 | #ifndef DISABLE_URING 190 | if (backend == "uring") 191 | return new UringAsyncIO(n_entries, n_tasks); 192 | #endif 193 | #ifndef DISABLE_AIO 194 | if (backend == "aio") 195 | return new AIOAsyncIO(n_entries, n_tasks); 196 | #endif 197 | #ifndef DISABLE_PTHREAD 198 | if (backend == "pthread") 199 | return new PthreadAsyncIO(n_entries, n_tasks); 200 | #endif 201 | throw std::runtime_error("Unsupported backend: " + backend); 202 | } -------------------------------------------------------------------------------- /csrc/offload.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "backend.h" 15 | #include "offload.h" 16 | #include "space_mgr.h" 17 | 18 | iovec *tensors_to_iovec(const std::vector &tensors) 19 | { 20 | iovec *iovs = static_cast(calloc(tensors.size(), sizeof(iovec))); 21 | for (size_t i = 0; i < tensors.size(); i++) 22 | { 23 | iovs[i].iov_base = tensors[i].data_ptr(); 24 | iovs[i].iov_len = tensors[i].storage().nbytes(); 25 | } 26 | return iovs; 27 | } 28 | 29 | Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0)) 30 | { 31 | this->aio = create_asyncio(n_entries, backend, 0); 32 | this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 33 | this->aio->register_file(fd); 34 | } 35 | 36 | SpaceInfo Offloader::prepare_write(const at::Tensor &tensor, const std::string &key) 37 | { 38 | if (!tensor.is_contiguous() || !tensor.is_cpu()) 39 | throw std::runtime_error("Tensor must be contiguous and on cpu"); 40 | ull bytes = tensor.storage().nbytes(); 41 | ull offset = this->space_mgr.alloc(bytes); 42 | SpaceInfo space_info(offset, bytes); 43 | this->tensors_info[key] = space_info; 44 | return space_info; 45 | } 46 | 47 | SpaceInfo Offloader::prepare_read(const at::Tensor &tensor, const std::string &key) 48 | { 49 | if (!tensor.is_contiguous() || !tensor.is_cpu()) 50 | throw std::runtime_error("Tensor must be contiguous and on cpu"); 51 | if (this->tensors_info.find(key) == this->tensors_info.end()) 52 | throw std::runtime_error("Read error, tensor not found"); 53 | ull bytes = tensor.storage().nbytes(); 54 | SpaceInfo space_info = this->tensors_info[key]; 55 | if (bytes != space_info.second) 56 | throw std::runtime_error("Read error, tensor shape mismatch"); 57 | this->tensors_info.erase(key); 58 | return space_info; 59 | } 60 | 61 | void Offloader::async_write(const at::Tensor &tensor, const std::string &key, callback_t callback) 62 | { 63 | ull offset, bytes; 64 | std::tie(offset, bytes) = prepare_write(tensor, key); 65 | this->aio->write(this->fd, tensor.data_ptr(), bytes, offset, callback); 66 | 67 | this->aio->get_event(NOWAIT); 68 | } 69 | 70 | void Offloader::async_read(const at::Tensor &tensor, const std::string &key, callback_t callback) 71 | { 72 | ull offset, bytes; 73 | std::tie(offset, bytes) = prepare_read(tensor, key); 74 | auto fn = std::bind(&Offloader::release, this, offset, bytes, callback); 75 | this->aio->read(this->fd, tensor.data_ptr(), bytes, offset, fn); 76 | 77 | this->aio->get_event(NOWAIT); 78 | } 79 | 80 | void Offloader::sync_write(const at::Tensor &tensor, const std::string &key) 81 | { 82 | ull offset, bytes; 83 | std::tie(offset, bytes) = prepare_write(tensor, key); 84 | lseek(this->fd, offset, 
85 |     write(this->fd, tensor.data_ptr(), bytes);
86 | }
87 | 
88 | void Offloader::sync_read(const at::Tensor &tensor, const std::string &key)
89 | {
90 |     ull offset, bytes;
91 |     std::tie(offset, bytes) = prepare_read(tensor, key);
92 |     lseek(this->fd, offset, SEEK_SET);
93 |     read(this->fd, tensor.data_ptr(), bytes);
94 |     release(offset, bytes);
95 | }
96 | 
97 | void Offloader::sync_write_events()
98 | {
99 |     this->aio->sync_write_events();
100 | }
101 | 
102 | void Offloader::sync_read_events()
103 | {
104 |     this->aio->sync_read_events();
105 | }
106 | 
107 | void Offloader::synchronize()
108 | {
109 |     this->aio->synchronize();
110 | }
111 | 
112 | Offloader::~Offloader()
113 | {
114 |     errno = 0;
115 |     delete this->aio;
116 |     close(this->fd);
117 |     if (remove(this->filename.c_str()) != 0)
118 |         printf("Remove \"%s\" error(%d): %s\n", this->filename.c_str(), errno, strerror(errno));
119 | }
120 | 
121 | SpaceInfo Offloader::prepare_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
122 | {
123 |     ull total_bytes = 0;
124 |     for (const at::Tensor &tensor : tensors)
125 |     {
126 |         if (!tensor.is_contiguous() || !tensor.is_cpu())
127 |             throw std::runtime_error("Tensor must be contiguous and on cpu");
128 |         total_bytes += tensor.storage().nbytes();
129 |     }
130 |     ull offset = this->space_mgr.alloc(total_bytes);
131 |     SpaceInfo space_info(offset, total_bytes);
132 |     this->tensors_info[key] = space_info;
133 |     return space_info;
134 | }
135 | 
136 | SpaceInfo Offloader::prepare_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
137 | {
138 |     ull total_bytes = 0;
139 |     for (const at::Tensor &tensor : tensors)
140 |     {
141 |         if (!tensor.is_contiguous() || !tensor.is_cpu())
142 |             throw std::runtime_error("Tensor must be contiguous and on cpu");
143 |         total_bytes += tensor.storage().nbytes();
144 |     }
145 |     if (this->tensors_info.find(key) == this->tensors_info.end())
146 |         throw std::runtime_error("Read error, tensor not found");
147 |     SpaceInfo space_info = this->tensors_info[key];
148 |     if (total_bytes != space_info.second)
149 |         throw std::runtime_error("Read error, tensor shape mismatch");
150 |     this->tensors_info.erase(key);
151 |     return space_info;
152 | }
153 | 
154 | void Offloader::async_writev(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback)
155 | {
156 |     ull offset, bytes;
157 |     std::tie(offset, bytes) = prepare_writev(tensors, key);
158 |     iovec *iov = tensors_to_iovec(tensors);
159 |     this->aio->writev(this->fd, iov, tensors.size(), offset, callback);
160 | 
161 |     this->aio->get_event(NOWAIT);
162 | }
163 | 
164 | void Offloader::async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback)
165 | {
166 | 
167 |     ull offset, bytes;
168 |     std::tie(offset, bytes) = prepare_readv(tensors, key);
169 |     iovec *iov = tensors_to_iovec(tensors);
170 |     auto fn = std::bind(&Offloader::release, this, offset, bytes, callback);
171 |     this->aio->readv(this->fd, iov, tensors.size(), offset, fn);
172 | 
173 |     this->aio->get_event(NOWAIT);
174 | }
175 | 
176 | void Offloader::sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
177 | {
178 |     ull offset, bytes;
179 |     std::tie(offset, bytes) = prepare_writev(tensors, key);
180 |     iovec *iov = tensors_to_iovec(tensors);
181 |     lseek(this->fd, offset, SEEK_SET);
182 |     writev(this->fd, iov, tensors.size());
183 |     free(iov); // allocated with calloc() in tensors_to_iovec(), so free() rather than delete
184 | }
185 | 
186 | void Offloader::sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
187 | {
188 |     ull offset, bytes;
189 |     std::tie(offset, bytes) = prepare_readv(tensors, key);
190 |     iovec *iov = tensors_to_iovec(tensors);
191 |     lseek(this->fd, offset, SEEK_SET);
192 |     readv(this->fd, iov, tensors.size());
193 |     free(iov); // allocated with calloc() in tensors_to_iovec(), so free() rather than delete
194 | }
195 | 
196 | void Offloader::release(ull offset, ull bytes, callback_t callback)
197 | {
198 |     this->space_mgr.free(offset, bytes);
199 |     if (callback != nullptr)
200 |         callback();
201 | }
--------------------------------------------------------------------------------
/csrc/pthread_backend.cpp:
--------------------------------------------------------------------------------
1 | #include "pthread_backend.h"
2 | #include <fstream>
3 | #include <iostream>
4 | 
5 | void write_file(const std::string &filename, const std::string &content)
6 | {
7 |     std::ofstream file(filename, std::ios::app);
8 |     file << content << std::endl;
9 |     file.close();
10 | }
11 | 
12 | void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
13 | {
14 |     auto fut = this->pool.submit_task(
15 |         [this, fd, buffer, n_bytes, offset]
16 |         {
17 |             auto val = pwrite(fd, buffer, n_bytes, offset);
18 |             if (this->is_debug)
19 |             {
20 |                 auto cur_tasks = this->tasks_in_progress.fetch_add(1);
21 |                 if (cur_tasks + 1 == this->total_tasks)
22 |                 {
23 |                     if (this->debug_log.empty())
24 |                     {
25 |                         std::cout << "All tasks are completed" << std::endl;
26 |                     }
27 |                     else
28 |                     {
29 |                         write_file(this->debug_log, "All tasks are completed");
30 |                     }
31 |                 }
32 |             }
33 |             return val;
34 |         });
35 |     this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
36 | }
37 | 
38 | void PthreadAsyncIO::writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback)
39 | {
40 |     auto fut = this->pool.submit_task(
41 |         [fd, iov, iovcnt, offset]
42 |         {
43 |             return pwritev(fd, iov, iovcnt, offset);
44 |         });
45 |     this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
46 | }
47 | 
48 | void PthreadAsyncIO::read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
49 | {
50 |     auto fut = this->pool.submit_task(
51 |         [fd, buffer, n_bytes, offset]
52 |         {
53 |             return pread(fd, buffer, n_bytes, offset);
54 |         });
55 |     this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
56 | }
57 | 
58 | void PthreadAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback)
59 | {
60 |     auto fut = this->pool.submit_task(
61 |         [fd, iov, iovcnt, offset]
62 |         {
63 |             return preadv(fd, iov, iovcnt, offset);
64 |         });
65 |     this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
66 | }
67 | 
68 | void PthreadAsyncIO::get_event(WaitType wt)
69 | {
70 |     if (wt == NOWAIT)
71 |         return;
72 |     this->sync_write_events();
73 |     this->sync_read_events();
74 | }
75 | 
76 | void PthreadAsyncIO::sync_write_events()
77 | {
78 |     while (this->write_fut.size() > 0)
79 |     {
80 |         auto front = std::move(this->write_fut.front());
81 |         this->write_fut.pop_front();
82 | 
83 |         auto fut(std::move(std::get<0>(front)));
84 |         fut.wait();
85 | 
86 |         auto callback = std::get<1>(front);
87 |         if (callback != nullptr)
88 |         {
89 |             callback();
90 |         }
91 |     }
92 | }
93 | 
94 | void PthreadAsyncIO::sync_read_events()
95 | {
96 |     while (this->read_fut.size() > 0)
97 |     {
98 |         auto front = std::move(this->read_fut.front());
99 |         this->read_fut.pop_front();
100 | 
101 |         auto fut(std::move(std::get<0>(front)));
102 |         fut.wait();
103 | 
104 |         auto callback = std::get<1>(front);
105 |         if (callback != nullptr)
106 |         {
107 |             callback();
108 |         }
109 |     }
110 | }
111 | 
112 | void PthreadAsyncIO::synchronize()
113 | {
114 |     this->get_event(WAIT);
115 | }
116 | 
117 | void PthreadAsyncIO::register_file(int fd) {}
118 | 
119 | void PthreadAsyncIO::register_h2d(unsigned int num_tensors)
120 | {
121 |     this->total_h2d = num_tensors;
122 | }
123 | 
124 | void PthreadAsyncIO::sync_h2d()
125 | {
126 |     std::unique_lock<std::mutex> lock(this->mtx);
127 |     this->cv.wait(lock, [this]
128 |                   { return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed
129 | }
130 | 
131 | void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
132 | {
133 |     auto stream = c10::cuda::getCurrentCUDAStream();
134 |     if (!t.is_cuda())
135 |     {
136 |         auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already on cpu, no h2d copy needed
137 |         if (cur_h2d + 1 == this->total_h2d)
138 |         { // notify when all h2d are completed and it is safe to call optimizer.step()
139 |             std::lock_guard<std::mutex> lock(this->mtx);
140 |             cv.notify_one();
141 |         }
142 |     }
143 |     auto fut = this->pool.submit_task(
144 |         [this, fd, t, offset, pinned, stream]
145 |         {
146 |             torch::Tensor cpu_tensor;
147 |             if (t.is_cuda())
148 |             {
149 |                 at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
150 |                 if (pinned.has_value())
151 |                 {
152 |                     pinned.value().copy_(t, /*non_blocking*/ false);
153 |                     cpu_tensor = pinned.value();
154 |                 }
155 |                 else
156 |                 {
157 |                     cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
158 |                 }
159 |                 auto cur_h2d = this->h2d_in_progress.fetch_add(1);
160 |                 if (cur_h2d + 1 == this->total_h2d)
161 |                 { // notify when all h2d are completed and it is safe to call optimizer.step()
162 |                     std::lock_guard<std::mutex> lock(this->mtx);
163 |                     cv.notify_one();
164 |                 }
165 |             }
166 |             else
167 |             {
168 |                 cpu_tensor = t;
169 |             }
170 |             void *buf = cpu_tensor.data_ptr();
171 |             size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
172 |             auto val = pwrite(fd, buf, n_bytes, offset);
173 |             if (this->is_debug)
174 |             {
175 |                 auto cur_tasks = this->tasks_in_progress.fetch_add(1);
176 |                 if (cur_tasks + 1 == this->total_tasks)
177 |                 {
178 |                     if (this->debug_log.empty())
179 |                     {
180 |                         std::cout << "All tasks are completed" << std::endl;
181 |                     }
182 |                     else
183 |                     {
184 |                         write_file(this->debug_log, "All tasks are completed");
185 |                     }
186 |                 }
187 |             }
188 |             return val;
189 |         });
190 |     this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
191 | }
192 | 
193 | void PthreadAsyncIO::register_tasks(unsigned int num_tasks)
194 | {
195 |     if (this->is_debug)
196 |     {
197 |         this->total_tasks = num_tasks;    // number of tasks expected before "All tasks are completed" is reported
198 |         this->tasks_in_progress.store(0); // reset the completed-task counter
199 |     }
200 | }
--------------------------------------------------------------------------------
/csrc/py_api.cpp:
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 | #include <pybind11/functional.h>
3 | #include <pybind11/stl.h>
4 | #include <torch/extension.h>
5 | #include "offload.h"
6 | #include "async_file_io.h"
7 | #include "backend.h"
8 | #include <string>
9 | 
10 | namespace py = pybind11;
11 | 
12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
13 | {
14 |     py::class_<Offloader>(m, "Offloader")
15 |         .def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio")
16 |         .def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
17 |         .def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
18 |         .def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key"))
py::arg("key")) 19 | .def("sync_read", &Offloader::sync_read, py::arg("tensor"), py::arg("key")) 20 | .def("sync_write_events", &Offloader::sync_write_events) 21 | .def("sync_read_events", &Offloader::sync_write_events) 22 | .def("synchronize", &Offloader::synchronize) 23 | .def("async_writev", &Offloader::async_writev, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none()) 24 | .def("async_readv", &Offloader::async_readv, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none()) 25 | .def("sync_writev", &Offloader::sync_writev, py::arg("tensors"), py::arg("key")) 26 | .def("sync_readv", &Offloader::sync_readv, py::arg("tensors"), py::arg("key")); 27 | m.def("get_backends", get_backends); 28 | m.def("probe_backend", probe_backend, py::arg("backend")); 29 | py::class_(m, "AsyncFileWriter") 30 | .def(py::init(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio", py::arg("n_tasks") = 0) 31 | .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none()) 32 | .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none()) 33 | .def("synchronize", &AsyncFileWriter::synchronize) 34 | .def("sync_h2d", &AsyncFileWriter::sync_h2d) 35 | .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors")) 36 | .def("register_tasks", &AsyncFileWriter::register_tasks, py::arg("num_tasks")); 37 | } -------------------------------------------------------------------------------- /csrc/space_mgr.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "space_mgr.h" 3 | #include 4 | 5 | SpaceManager::SpaceManager(unsigned long long limit) : limit(limit), used_bytes(0) 6 | { 7 | } 8 | 9 | SpaceManager::~SpaceManager() 10 | { 11 | } 12 | 13 | ull SpaceManager::alloc(ull bytes) 14 | { 15 | if (bytes == 0) 16 | throw std::runtime_error("Invalid alloc size (0)"); 17 | auto target_iter = avail_spaces.end(); 18 | ull target_bytes = 0; 19 | for (auto iter = avail_spaces.begin(); iter != avail_spaces.end(); iter++) 20 | { 21 | if (iter->second >= bytes && (target_iter == avail_spaces.end() || iter->second < target_bytes)) 22 | { 23 | target_iter = iter; 24 | target_bytes = iter->second; 25 | } 26 | } 27 | // no available space, use new space 28 | if (target_iter == avail_spaces.end()) 29 | { 30 | // limit=0 means unlimit 31 | if (limit > 0 && limit - used_bytes < bytes) 32 | throw std::runtime_error("File size exceed limit"); 33 | ull offset = used_bytes; 34 | used_bytes += bytes; 35 | return offset; 36 | } 37 | ull offset = target_iter->first; 38 | target_iter->first += bytes; 39 | target_iter->second -= bytes; 40 | if (target_iter->second == 0) 41 | avail_spaces.erase(target_iter); 42 | return offset; 43 | } 44 | 45 | void SpaceManager::free(ull offset, ull bytes) 46 | { 47 | if (bytes == 0) 48 | throw std::runtime_error("Invalid free size (0)"); 49 | SpaceInfo new_avail_space(offset, bytes); 50 | for (auto iter = avail_spaces.begin(); iter != avail_spaces.end();) 51 | { 52 | if (offset > iter->first && offset - iter->first == iter->second) 53 | { 54 | new_avail_space.first = iter->first; 55 | new_avail_space.second += iter->second; 56 | iter = avail_spaces.erase(iter); 57 | } 58 | else if (offset < iter->first && iter->first - offset == bytes) 59 | { 60 | new_avail_space.second += iter->second; 61 | iter = avail_spaces.erase(iter); 62 | } 63 | else 64 
-------------------------------------------------------------------------------- /csrc/uring.cpp: -------------------------------------------------------------------------------- 1 | #include <memory> 2 | #include <stdexcept> 3 | #include "uring.h" 4 | 5 | UringAsyncIO::UringAsyncIO(unsigned int n_entries, unsigned int n_tasks) : n_write_events(0), n_read_events(0), n_entries(n_entries) // n_tasks is only meaningful for the pthread backend 6 | { 7 | io_uring_queue_init(n_entries, &this->ring, 0); 8 | } 9 | 10 | void UringAsyncIO::register_file(int fd) 11 | { 12 | io_uring_register_files(&ring, &fd, 1); 13 | } 14 | 15 | UringAsyncIO::~UringAsyncIO() 16 | { 17 | synchronize(); 18 | io_uring_queue_exit(&this->ring); 19 | } 20 | 21 | void UringAsyncIO::get_event(WaitType wt) 22 | { 23 | io_uring_cqe *cqe; 24 | if (wt == WAIT) 25 | { 26 | io_uring_wait_cqe(&this->ring, &cqe); 27 | } 28 | else 29 | { 30 | int ret = io_uring_peek_cqe(&this->ring, &cqe); 31 | if (ret != 0) 32 | return; 33 | } 34 | 35 | std::unique_ptr<IOData> data(static_cast<IOData *>(io_uring_cqe_get_data(cqe))); 36 | if (data->type == WRITE) 37 | this->n_write_events--; 38 | else if (data->type == READ) 39 | this->n_read_events--; 40 | else 41 | throw std::runtime_error("Unknown IO event type"); 42 | io_uring_cqe_seen(&this->ring, cqe); 43 | if (data->callback != nullptr) 44 | data->callback(); 45 | } 46 | 47 | void UringAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) 48 | { 49 | io_uring_sqe *sqe = io_uring_get_sqe(&this->ring); 50 | IOData *data = new IOData(WRITE, callback); 51 | io_uring_prep_write(sqe, fd, buffer, n_bytes, offset); 52 | io_uring_sqe_set_data(sqe, data); 53 | io_uring_submit(&this->ring); 54 | this->n_write_events++; 55 | } 56 | 57 | void UringAsyncIO::read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) 58 | { 59 | io_uring_sqe *sqe = io_uring_get_sqe(&this->ring); 60 | IOData *data = new IOData(READ, callback); 61 | io_uring_prep_read(sqe, fd, buffer, n_bytes, offset); 62 | io_uring_sqe_set_data(sqe, data); 63 | io_uring_submit(&this->ring); 64 | this->n_read_events++; 65 | } 66 | 67 | void UringAsyncIO::sync_write_events() 68 | { 69 | while (this->n_write_events > 0) 70 | get_event(WAIT); 71 | } 72 | 73 | void UringAsyncIO::sync_read_events() 74 | { 75 | while (this->n_read_events > 0) 76 | get_event(WAIT); 77 | } 78 | 79 | void UringAsyncIO::synchronize() 80 | { 81 | while (this->n_write_events > 0 || this->n_read_events > 0) 82 | get_event(WAIT); 83 | } 84 | 85 | void UringAsyncIO::writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) 86 | { 87 | io_uring_sqe *sqe = io_uring_get_sqe(&this->ring); 88 | IOData *data = new IOData(WRITE, callback, iov); 89 | io_uring_prep_writev(sqe, fd, iov, iovcnt, offset); 90 | io_uring_sqe_set_data(sqe, data); 91 | io_uring_submit(&this->ring); 92 | this->n_write_events++; 93 | } 94 | 95 | void UringAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) 96 | { 97 | io_uring_sqe *sqe = io_uring_get_sqe(&this->ring); 98 | IOData *data = new IOData(READ, callback, iov); 99 | io_uring_prep_readv(sqe, fd, iov, iovcnt, offset); 100 | io_uring_sqe_set_data(sqe, data); 101 | io_uring_submit(&this->ring); 102 | this->n_read_events++; 103 | } 104 | 105 | void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) 106 | { 107 | if (t.is_cuda()) 108 | { 109 | if (pinned.has_value()) 110 | { 111 | pinned.value().copy_(t); 112 | t = pinned.value(); 113 | } 114 | else 115 | { 116 | t = t.to(torch::kCPU); 117 | } 118 | } 119 | void *buffer = t.data_ptr(); 120 | size_t n_bytes = t.numel() * t.element_size(); 121 | this->write(fd, buffer, n_bytes, offset, callback); 122 | } 123 | 124 | void UringAsyncIO::register_h2d(unsigned int num_tensors) {} // no-op: H2D staging is only used by the pthread backend 125 | void UringAsyncIO::sync_h2d() {} // no-op 126 | void UringAsyncIO::register_tasks(unsigned int num_tasks) {} // no-op
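A sketch of driving this backend directly from C++, assuming liburing is installed and the repo's `uring.h` is on the include path (in practice the class is reached through `create_asyncio` or the Python bindings):

```cpp
#include <fcntl.h>
#include <unistd.h>
#include <cstdio>
#include <cstring>
#include "uring.h"

int main()
{
    int fd = open("/tmp/demo.bin", O_CREAT | O_WRONLY, 0644);
    if (fd < 0)
        return 1;

    UringAsyncIO aio(8 /* n_entries */, 0 /* n_tasks, unused by this backend */);

    char buf[4096];
    memset(buf, 'x', sizeof(buf));

    // submit an asynchronous write at offset 0; the callback runs on completion
    aio.write(fd, buf, sizeof(buf), 0, [] { printf("write completed\n"); });

    aio.synchronize(); // drain the completion queue before buf goes out of scope
    close(fd);
    return 0;
}
```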
-------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM hpcaitech/pytorch-cuda:1.11.0-11.3.0 2 | 3 | # install dependencies 4 | RUN conda install -y cmake 5 | 6 | # install tensornvme 7 | RUN git clone https://github.com/hpcaitech/TensorNVMe.git && \ 8 | cd TensorNVMe && \ 9 | pip install -r requirements.txt && \ 10 | pip install -v --no-cache-dir .
-------------------------------------------------------------------------------- /include/aio.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <libaio.h> 4 | #include <ctime> 5 | #include <optional> 6 | #include <torch/torch.h> 7 | #include "asyncio.h" 8 | 9 | class AIOAsyncIO : public AsyncIO 10 | { 11 | private: 12 | io_context_t io_ctx = nullptr; 13 | int n_write_events = 0; /* number of pending write events */ 14 | int n_read_events = 0; 15 | int max_nr; 16 | int min_nr = 1; 17 | struct timespec timeout; 18 | 19 | void get_event(WaitType wt); 20 | 21 | public: 22 | AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks); 23 | ~AIOAsyncIO(); 24 | 25 | void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback); 26 | void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback); 27 | void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); 28 | void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); 29 | 30 | void register_h2d(unsigned int num_tensors); 31 | void sync_h2d(); 32 | void sync_write_events(); 33 | void sync_read_events(); 34 | void synchronize(); 35 | 36 | void register_file(int fd); 37 | void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned); 38 | void register_tasks(unsigned int num_tasks); 39 | };
-------------------------------------------------------------------------------- /include/async_file_io.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <string> 3 | #include <optional> 4 | #include <torch/torch.h> 5 | 6 | #include "asyncio.h" 7 | #include "backend.h" 8 | 9 | #ifndef DISABLE_URING 10 | #include "uring.h" 11 | #endif 12 | 13 | #ifndef DISABLE_AIO 14 | #include "aio.h" 15 | #endif 16 | 17 | class AsyncFileWriter 18 | { 19 | public: 20 | AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks); 21 | void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback); // buffer carries a raw pointer value (cast from the Python side) 22 | void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned); 23 | void synchronize(); 24 | void register_h2d(unsigned int num_tensors); 25 | void sync_h2d(); 26 | void register_tasks(unsigned int num_tasks); 27 | ~AsyncFileWriter(); 28 | 29 | private: 30 | int fd; 31 | AsyncIO *aio; 32 | };
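The Python `tensornvme.async_file_io` module is the usual entry point, but the same class can be exercised from C++. A minimal sketch against the header above (no pinned staging buffer, so `pinned` is `std::nullopt`):

```cpp
#include <fcntl.h>
#include <unistd.h>
#include <torch/torch.h>
#include "async_file_io.h"

int main()
{
    int fd = open("/tmp/ckpt.bin", O_CREAT | O_WRONLY, 0644);
    AsyncFileWriter writer(fd, 16 /* n_entries */, "aio", 0 /* n_tasks */);

    torch::Tensor t = torch::randn({1024});
    // no pinned staging buffer here, so pass std::nullopt; no completion callback
    writer.write_tensor(t, 0 /* offset */, nullptr /* callback */, std::nullopt);

    writer.synchronize(); // wait for the write before closing fd or reusing t
    close(fd);
    return 0;
}
```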
-------------------------------------------------------------------------------- /include/asyncio.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <functional> 4 | #include <sys/uio.h> 5 | #include <torch/torch.h> 6 | 7 | using callback_t = std::function<void()>; 8 | 9 | enum IOType 10 | { 11 | WRITE, 12 | READ 13 | }; 14 | 15 | enum WaitType 16 | { 17 | WAIT, 18 | NOWAIT 19 | }; 20 | 21 | struct IOData 22 | { 23 | IOType type; 24 | callback_t callback; 25 | const iovec *iov; 26 | 27 | IOData(IOType type, callback_t callback = nullptr, const iovec *iov = nullptr) : type(type), callback(callback), iov(iov) {} 28 | ~IOData() 29 | { 30 | if (iov) 31 | delete[] iov; // iov is allocated as an array by the vectored I/O paths 32 | } 33 | }; 34 | 35 | class AsyncIO 36 | { 37 | public: 38 | virtual ~AsyncIO() = default; 39 | 40 | virtual void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) = 0; 41 | virtual void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) = 0; 42 | virtual void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) = 0; 43 | virtual void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) = 0; 44 | 45 | virtual void get_event(WaitType wt) = 0; 46 | virtual void sync_write_events() = 0; 47 | virtual void sync_read_events() = 0; 48 | virtual void register_h2d(unsigned int num_tensors) = 0; 49 | virtual void register_tasks(unsigned int num_tasks) = 0; 50 | virtual void sync_h2d() = 0; 51 | virtual void synchronize() = 0; 52 | 53 | virtual void register_file(int fd) = 0; 54 | virtual void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) = 0; 55 | };
-------------------------------------------------------------------------------- /include/backend.h: -------------------------------------------------------------------------------- 1 | #include "asyncio.h" 2 | #include <string> 3 | #include <unordered_set> 4 | #include <memory> 5 | #include <fstream> 6 | #include <iostream> 7 | #include <cstdlib> 8 | 9 | std::unordered_set<std::string> get_backends(); 10 | 11 | bool probe_backend(const std::string &backend); 12 | 13 | std::string get_default_backend(); 14 | 15 | bool get_debug_flag(); 16 | 17 | AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks); 18 | 19 | std::string get_debug_log();
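A short sketch of the intended backend-selection flow using the functions declared above (which backends are compiled in, and which probe successfully, depends on the build and the running kernel):

```cpp
#include <cstdio>
#include <memory>
#include "backend.h"

int main()
{
    // list the backends compiled into this build and check which are usable here
    for (const std::string &name : get_backends())
        printf("%s usable: %d\n", name.c_str(), probe_backend(name));

    // construct the default backend through the factory; any AsyncIO* works the same way
    std::unique_ptr<AsyncIO> aio(create_asyncio(8 /* n_entries */, get_default_backend(), 0 /* n_tasks */));
    // aio->write(...) / aio->synchronize() as with any AsyncIO implementation
    return 0;
}
```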
&key, callback_t callback = nullptr); 22 | void sync_write(const at::Tensor &tensor, const std::string &key); 23 | void sync_read(const at::Tensor &tensor, const std::string &key); 24 | void sync_write_events(); 25 | void sync_read_events(); 26 | void synchronize(); 27 | ~Offloader(); 28 | SpaceInfo prepare_writev(const std::vector &tensors, const std::string &key); 29 | SpaceInfo prepare_readv(const std::vector &tensors, const std::string &key); 30 | void async_writev(const std::vector &tensors, const std::string &key, callback_t callback = nullptr); 31 | void async_readv(const std::vector &tensors, const std::string &key, callback_t callback = nullptr); 32 | void sync_writev(const std::vector &tensors, const std::string &key); 33 | void sync_readv(const std::vector &tensors, const std::string &key); 34 | private: 35 | const std::string filename; 36 | int fd; 37 | AsyncIO *aio; 38 | SpaceManager space_mgr; 39 | std::unordered_map tensors_info; 40 | 41 | void release(ull offset, ull bytes, callback_t callback = nullptr); 42 | }; -------------------------------------------------------------------------------- /include/pthread_backend.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "asyncio.h" 20 | #include "threadpool.hpp" 21 | #include "backend.h" 22 | #include 23 | 24 | class PthreadAsyncIO : public AsyncIO 25 | { 26 | private: 27 | BS::thread_pool pool; 28 | std::atomic h2d_in_progress; 29 | unsigned int total_h2d; 30 | std::condition_variable cv; 31 | std::mutex mtx; 32 | std::deque, callback_t>> write_fut; 33 | std::deque, callback_t>> read_fut; 34 | const bool is_debug = get_debug_flag(); 35 | const std::string debug_log = get_debug_log(); 36 | 37 | std::atomic tasks_in_progress; 38 | unsigned int total_tasks; 39 | 40 | public: 41 | PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks) 42 | : pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {} 43 | 44 | ~PthreadAsyncIO() {} 45 | 46 | void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback); 47 | void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback); 48 | void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); 49 | void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); 50 | 51 | void get_event(WaitType wt); 52 | void sync_write_events(); 53 | void sync_read_events(); 54 | void register_h2d(unsigned int num_tensors); 55 | void sync_h2d(); 56 | void synchronize(); 57 | 58 | void register_file(int fd); 59 | 60 | void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned); 61 | void register_tasks(unsigned int num_tasks); 62 | }; -------------------------------------------------------------------------------- /include/space_mgr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | using std::vector; 7 | using ull = unsigned long long; 8 | // (offset, bytes) pair 9 | using SpaceInfo = std::pair; 10 | 11 | class SpaceManager 12 | { 13 | private: 14 | ull limit, used_bytes; 15 | vector avail_spaces; 
-------------------------------------------------------------------------------- /include/space_mgr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <utility> 4 | #include <vector> 5 | 6 | using std::vector; 7 | using ull = unsigned long long; 8 | // (offset, bytes) pair 9 | using SpaceInfo = std::pair<ull, ull>; 10 | 11 | class SpaceManager 12 | { 13 | private: 14 | ull limit, used_bytes; 15 | vector<SpaceInfo> avail_spaces; 16 | 17 | public: 18 | SpaceManager(ull limit); 19 | ~SpaceManager(); 20 | ull alloc(ull bytes); 21 | void free(ull offset, ull bytes); 22 | void print(); 23 | };
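The vendored header below is BS::thread_pool v4.1.0, which `PthreadAsyncIO` builds on. Its core API, as documented in the file itself, is `submit_task` (returns a future), `detach_task` (fire-and-forget), and `wait`; a minimal sketch:

```cpp
#include <cstdio>
#include <future>
#include "threadpool.hpp"

int main()
{
    BS::thread_pool pool(4); // four worker threads

    // submit_task returns a future for the task's result
    std::future<int> fut = pool.submit_task([] { return 6 * 7; });
    printf("%d\n", fut.get());

    // detach_task is fire-and-forget; wait() blocks until the queue drains
    pool.detach_task([] { puts("background work"); });
    pool.wait();
    return 0;
}
```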
-------------------------------------------------------------------------------- /include/threadpool.hpp: -------------------------------------------------------------------------------- 1 | // Copied from https://github.com/bshoshany/thread-pool/blob/097aa718f25d44315cadb80b407144ad455ee4f9/include/BS_thread_pool.hpp 2 | 3 | /* Original license 4 | 5 | MIT License 6 | 7 | Copyright (c) 2024 Barak Shoshany 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | */ 27 | #ifndef BS_THREAD_POOL_HPP 28 | #define BS_THREAD_POOL_HPP 29 | /** 30 | * @file BS_thread_pool.hpp 31 | * @author Barak Shoshany (baraksh@gmail.com) (https://baraksh.com) 32 | * @version 4.1.0 33 | * @date 2024-03-22 34 | * @copyright Copyright (c) 2024 Barak Shoshany. Licensed under the MIT license. If you found this project useful, please consider starring it on GitHub! If you use this library in software of any kind, please provide a link to the GitHub repository https://github.com/bshoshany/thread-pool in the source code and documentation. If you use this library in published research, please cite it as follows: Barak Shoshany, "A C++17 Thread Pool for High-Performance Scientific Computing", doi:10.1016/j.softx.2024.101687, SoftwareX 26 (2024) 101687, arXiv:2105.00613 35 | * 36 | * @brief BS::thread_pool: a fast, lightweight, and easy-to-use C++17 thread pool library. This header file contains the main thread pool class and some additional classes and definitions. No other files are needed in order to use the thread pool itself. 37 | */ 38 | 39 | #ifndef __cpp_exceptions 40 | #define BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING 41 | #undef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 42 | #endif 43 | 44 | #include <chrono> // std::chrono 45 | #include <condition_variable> // std::condition_variable 46 | #include <cstddef> // std::size_t 47 | #ifdef BS_THREAD_POOL_ENABLE_PRIORITY 48 | #include <cstdint> // std::int_least16_t 49 | #endif 50 | #ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING 51 | #include <exception> // std::current_exception 52 | #endif 53 | #include <functional> // std::function 54 | #include <future> // std::future, std::future_status, std::promise 55 | #include <memory> // std::make_shared, std::make_unique, std::shared_ptr, std::unique_ptr 56 | #include <mutex> // std::mutex, std::scoped_lock, std::unique_lock 57 | #include <optional> // std::nullopt, std::optional 58 | #include <queue> // std::priority_queue (if priority enabled), std::queue 59 | #ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 60 | #include <stdexcept> // std::runtime_error 61 | #endif 62 | #include <thread> // std::thread 63 | #include <type_traits> // std::conditional_t, std::decay_t, std::invoke_result_t, std::is_void_v, std::remove_const_t (if priority enabled) 64 | #include <utility> // std::forward, std::move 65 | #include <vector> // std::vector 66 | 67 | /** 68 | * @brief A namespace used by Barak Shoshany's projects. 69 | */ 70 | namespace BS { 71 | // Macros indicating the version of the thread pool library. 72 | #define BS_THREAD_POOL_VERSION_MAJOR 4 73 | #define BS_THREAD_POOL_VERSION_MINOR 1 74 | #define BS_THREAD_POOL_VERSION_PATCH 0 75 | 76 | class thread_pool; 77 | 78 | /** 79 | * @brief A type to represent the size of things. 80 | */ 81 | using size_t = std::size_t; 82 | 83 | /** 84 | * @brief A convenient shorthand for the type of `std::thread::hardware_concurrency()`. Should evaluate to unsigned int. 85 | */ 86 | using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>; 87 | 88 | #ifdef BS_THREAD_POOL_ENABLE_PRIORITY 89 | /** 90 | * @brief A type used to indicate the priority of a task. Defined to be an integer with a width of (at least) 16 bits. 91 | */ 92 | using priority_t = std::int_least16_t; 93 | 94 | /** 95 | * @brief A namespace containing some pre-defined priorities for convenience. 96 | */ 97 | namespace pr { 98 | constexpr priority_t highest = 32767; 99 | constexpr priority_t high = 16383; 100 | constexpr priority_t normal = 0; 101 | constexpr priority_t low = -16384; 102 | constexpr priority_t lowest = -32768; 103 | } // namespace pr 104 | 105 | // Macros used internally to enable or disable the priority arguments in the relevant functions. 106 | #define BS_THREAD_POOL_PRIORITY_INPUT , const priority_t priority = 0 107 | #define BS_THREAD_POOL_PRIORITY_OUTPUT , priority 108 | #else 109 | #define BS_THREAD_POOL_PRIORITY_INPUT 110 | #define BS_THREAD_POOL_PRIORITY_OUTPUT 111 | #endif 112 | 113 | /** 114 | * @brief A namespace used to obtain information about the current thread. 115 | */ 116 | namespace this_thread { 117 | /** 118 | * @brief A type returned by `BS::this_thread::get_index()` which can optionally contain the index of a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value. 119 | */ 120 | using optional_index = std::optional<size_t>; 121 | 122 | /** 123 | * @brief A type returned by `BS::this_thread::get_pool()` which can optionally contain the pointer to the pool that owns a thread, if that thread belongs to a `BS::thread_pool`. Otherwise, it will contain no value. 124 | */ 125 | using optional_pool = std::optional<thread_pool*>; 126 | 127 | /** 128 | * @brief A helper class to store information about the index of the current thread.
129 | */ 130 | class [[nodiscard]] thread_info_index 131 | { 132 | friend class BS::thread_pool; 133 | 134 | public: 135 | /** 136 | * @brief Get the index of the current thread. If this thread belongs to a `BS::thread_pool` object, it will have an index from 0 to `BS::thread_pool::get_thread_count() - 1`. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned. 137 | * 138 | * @return An `std::optional` object, optionally containing a thread index. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value. 139 | */ 140 | [[nodiscard]] optional_index operator()() const 141 | { 142 | return index; 143 | } 144 | 145 | private: 146 | /** 147 | * @brief The index of the current thread. 148 | */ 149 | optional_index index = std::nullopt; 150 | }; // class thread_info_index 151 | 152 | /** 153 | * @brief A helper class to store information about the thread pool that owns the current thread. 154 | */ 155 | class [[nodiscard]] thread_info_pool 156 | { 157 | friend class BS::thread_pool; 158 | 159 | public: 160 | /** 161 | * @brief Get the pointer to the thread pool that owns the current thread. If this thread belongs to a `BS::thread_pool` object, a pointer to that object will be returned. Otherwise, for example if this thread is the main thread or an independent `std::thread`, `std::nullopt` will be returned. 162 | * 163 | * @return An `std::optional` object, optionally containing a pointer to a thread pool. Unless you are 100% sure this thread is in a pool, first use `std::optional::has_value()` to check if it contains a value, and if so, use `std::optional::value()` to obtain that value. 164 | */ 165 | [[nodiscard]] optional_pool operator()() const 166 | { 167 | return pool; 168 | } 169 | 170 | private: 171 | /** 172 | * @brief A pointer to the thread pool that owns the current thread. 173 | */ 174 | optional_pool pool = std::nullopt; 175 | }; // class thread_info_pool 176 | 177 | /** 178 | * @brief A `thread_local` object used to obtain information about the index of the current thread. 179 | */ 180 | inline thread_local thread_info_index get_index; 181 | 182 | /** 183 | * @brief A `thread_local` object used to obtain information about the thread pool that owns the current thread. 184 | */ 185 | inline thread_local thread_info_pool get_pool; 186 | } // namespace this_thread 187 | 188 | /** 189 | * @brief A helper class to facilitate waiting for and/or getting the results of multiple futures at once. 190 | * 191 | * @tparam T The return type of the futures. 192 | */ 193 | template 194 | class [[nodiscard]] multi_future : public std::vector> 195 | { 196 | public: 197 | // Inherit all constructors from the base class `std::vector`. 198 | using std::vector>::vector; 199 | 200 | // The copy constructor and copy assignment operator are deleted. The elements stored in a `multi_future` are futures, which cannot be copied. 201 | multi_future(const multi_future&) = delete; 202 | multi_future& operator=(const multi_future&) = delete; 203 | 204 | // The move constructor and move assignment operator are defaulted. 205 | multi_future(multi_future&&) = default; 206 | multi_future& operator=(multi_future&&) = default; 207 | 208 | /** 209 | * @brief Get the results from all the futures stored in this `multi_future`, rethrowing any stored exceptions. 
210 | * 211 | * @return If the futures return `void`, this function returns `void` as well. Otherwise, it returns a vector containing the results. 212 | */ 213 | [[nodiscard]] std::conditional_t, void, std::vector> get() 214 | { 215 | if constexpr (std::is_void_v) 216 | { 217 | for (std::future& future : *this) 218 | future.get(); 219 | return; 220 | } 221 | else 222 | { 223 | std::vector results; 224 | results.reserve(this->size()); 225 | for (std::future& future : *this) 226 | results.push_back(future.get()); 227 | return results; 228 | } 229 | } 230 | 231 | /** 232 | * @brief Check how many of the futures stored in this `multi_future` are ready. 233 | * 234 | * @return The number of ready futures. 235 | */ 236 | [[nodiscard]] size_t ready_count() const 237 | { 238 | size_t count = 0; 239 | for (const std::future& future : *this) 240 | { 241 | if (future.wait_for(std::chrono::duration::zero()) == std::future_status::ready) 242 | ++count; 243 | } 244 | return count; 245 | } 246 | 247 | /** 248 | * @brief Check if all the futures stored in this `multi_future` are valid. 249 | * 250 | * @return `true` if all futures are valid, `false` if at least one of the futures is not valid. 251 | */ 252 | [[nodiscard]] bool valid() const 253 | { 254 | bool is_valid = true; 255 | for (const std::future& future : *this) 256 | is_valid = is_valid && future.valid(); 257 | return is_valid; 258 | } 259 | 260 | /** 261 | * @brief Wait for all the futures stored in this `multi_future`. 262 | */ 263 | void wait() const 264 | { 265 | for (const std::future& future : *this) 266 | future.wait(); 267 | } 268 | 269 | /** 270 | * @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified duration has passed. This function first waits for the first future for the desired duration. If that future is ready before the duration expires, this function waits for the second future for whatever remains of the duration. It continues similarly until the duration expires. 271 | * 272 | * @tparam R An arithmetic type representing the number of ticks to wait. 273 | * @tparam P An `std::ratio` representing the length of each tick in seconds. 274 | * @param duration The amount of time to wait. 275 | * @return `true` if all futures have been waited for before the duration expired, `false` otherwise. 276 | */ 277 | template 278 | bool wait_for(const std::chrono::duration& duration) const 279 | { 280 | const std::chrono::time_point start_time = std::chrono::steady_clock::now(); 281 | for (const std::future& future : *this) 282 | { 283 | future.wait_for(duration - (std::chrono::steady_clock::now() - start_time)); 284 | if (duration < std::chrono::steady_clock::now() - start_time) 285 | return false; 286 | } 287 | return true; 288 | } 289 | 290 | /** 291 | * @brief Wait for all the futures stored in this `multi_future`, but stop waiting after the specified time point has been reached. This function first waits for the first future until the desired time point. If that future is ready before the time point is reached, this function waits for the second future until the desired time point. It continues similarly until the time point is reached. 292 | * 293 | * @tparam C The type of the clock used to measure time. 294 | * @tparam D An `std::chrono::duration` type used to indicate the time point. 295 | * @param timeout_time The time point at which to stop waiting. 296 | * @return `true` if all futures have been waited for before the time point was reached, `false` otherwise. 
297 | */ 298 | template 299 | bool wait_until(const std::chrono::time_point& timeout_time) const 300 | { 301 | for (const std::future& future : *this) 302 | { 303 | future.wait_until(timeout_time); 304 | if (timeout_time < std::chrono::steady_clock::now()) 305 | return false; 306 | } 307 | return true; 308 | } 309 | }; // class multi_future 310 | 311 | /** 312 | * @brief A fast, lightweight, and easy-to-use C++17 thread pool class. 313 | */ 314 | class [[nodiscard]] thread_pool 315 | { 316 | public: 317 | // ============================ 318 | // Constructors and destructors 319 | // ============================ 320 | 321 | /** 322 | * @brief Construct a new thread pool. The number of threads will be the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads. 323 | */ 324 | thread_pool() : thread_pool(0, [] {}) {} 325 | 326 | /** 327 | * @brief Construct a new thread pool with the specified number of threads. 328 | * 329 | * @param num_threads The number of threads to use. 330 | */ 331 | explicit thread_pool(const concurrency_t num_threads) : thread_pool(num_threads, [] {}) {} 332 | 333 | /** 334 | * @brief Construct a new thread pool with the specified initialization function. 335 | * 336 | * @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed. 337 | */ 338 | explicit thread_pool(const std::function& init_task) : thread_pool(0, init_task) {} 339 | 340 | /** 341 | * @brief Construct a new thread pool with the specified number of threads and initialization function. 342 | * 343 | * @param num_threads The number of threads to use. 344 | * @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed. 345 | */ 346 | thread_pool(const concurrency_t num_threads, const std::function& init_task) : thread_count(determine_thread_count(num_threads)), threads(std::make_unique(determine_thread_count(num_threads))) 347 | { 348 | create_threads(init_task); 349 | } 350 | 351 | // The copy and move constructors and assignment operators are deleted. The thread pool uses a mutex, which cannot be copied or moved. 352 | thread_pool(const thread_pool&) = delete; 353 | thread_pool(thread_pool&&) = delete; 354 | thread_pool& operator=(const thread_pool&) = delete; 355 | thread_pool& operator=(thread_pool&&) = delete; 356 | 357 | /** 358 | * @brief Destruct the thread pool. Waits for all tasks to complete, then destroys all threads. Note that if the pool is paused, then any tasks still in the queue will never be executed. 359 | */ 360 | ~thread_pool() 361 | { 362 | wait(); 363 | destroy_threads(); 364 | } 365 | 366 | // ======================= 367 | // Public member functions 368 | // ======================= 369 | 370 | #ifdef BS_THREAD_POOL_ENABLE_NATIVE_HANDLES 371 | /** 372 | * @brief Get a vector containing the underlying implementation-defined thread handles for each of the pool's threads, as obtained by `std::thread::native_handle()`. Only enabled if `BS_THREAD_POOL_ENABLE_NATIVE_HANDLES` is defined. 373 | * 374 | * @return The native thread handles. 
375 | */ 376 | [[nodiscard]] std::vector get_native_handles() const 377 | { 378 | std::vector native_handles(thread_count); 379 | for (concurrency_t i = 0; i < thread_count; ++i) 380 | { 381 | native_handles[i] = threads[i].native_handle(); 382 | } 383 | return native_handles; 384 | } 385 | #endif 386 | 387 | /** 388 | * @brief Get the number of tasks currently waiting in the queue to be executed by the threads. 389 | * 390 | * @return The number of queued tasks. 391 | */ 392 | [[nodiscard]] size_t get_tasks_queued() const 393 | { 394 | const std::scoped_lock tasks_lock(tasks_mutex); 395 | return tasks.size(); 396 | } 397 | 398 | /** 399 | * @brief Get the number of tasks currently being executed by the threads. 400 | * 401 | * @return The number of running tasks. 402 | */ 403 | [[nodiscard]] size_t get_tasks_running() const 404 | { 405 | const std::scoped_lock tasks_lock(tasks_mutex); 406 | return tasks_running; 407 | } 408 | 409 | /** 410 | * @brief Get the total number of unfinished tasks: either still waiting in the queue, or running in a thread. Note that `get_tasks_total() == get_tasks_queued() + get_tasks_running()`. 411 | * 412 | * @return The total number of tasks. 413 | */ 414 | [[nodiscard]] size_t get_tasks_total() const 415 | { 416 | const std::scoped_lock tasks_lock(tasks_mutex); 417 | return tasks_running + tasks.size(); 418 | } 419 | 420 | /** 421 | * @brief Get the number of threads in the pool. 422 | * 423 | * @return The number of threads. 424 | */ 425 | [[nodiscard]] concurrency_t get_thread_count() const 426 | { 427 | return thread_count; 428 | } 429 | 430 | /** 431 | * @brief Get a vector containing the unique identifiers for each of the pool's threads, as obtained by `std::thread::get_id()`. 432 | * 433 | * @return The unique thread identifiers. 434 | */ 435 | [[nodiscard]] std::vector get_thread_ids() const 436 | { 437 | std::vector thread_ids(thread_count); 438 | for (concurrency_t i = 0; i < thread_count; ++i) 439 | { 440 | thread_ids[i] = threads[i].get_id(); 441 | } 442 | return thread_ids; 443 | } 444 | 445 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE 446 | /** 447 | * @brief Check whether the pool is currently paused. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined. 448 | * 449 | * @return `true` if the pool is paused, `false` if it is not paused. 450 | */ 451 | [[nodiscard]] bool is_paused() const 452 | { 453 | const std::scoped_lock tasks_lock(tasks_mutex); 454 | return paused; 455 | } 456 | 457 | /** 458 | * @brief Pause the pool. The workers will temporarily stop retrieving new tasks out of the queue, although any tasks already executed will keep running until they are finished. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined. 459 | */ 460 | void pause() 461 | { 462 | const std::scoped_lock tasks_lock(tasks_mutex); 463 | paused = true; 464 | } 465 | #endif 466 | 467 | /** 468 | * @brief Purge all the tasks waiting in the queue. Tasks that are currently running will not be affected, but any tasks still waiting in the queue will be discarded, and will never be executed by the threads. Please note that there is no way to restore the purged tasks. 469 | */ 470 | void purge() 471 | { 472 | const std::scoped_lock tasks_lock(tasks_mutex); 473 | while (!tasks.empty()) 474 | tasks.pop(); 475 | } 476 | 477 | /** 478 | * @brief Submit a function with no arguments and no return value into the task queue, with the specified priority. To push a function with arguments, enclose it in a lambda expression. 
Does not return a future, so the user must use `wait()` or some other method to ensure that the task finishes executing, otherwise bad things will happen. 479 | * 480 | * @tparam F The type of the function. 481 | * @param task The function to push. 482 | * @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 483 | */ 484 | template 485 | void detach_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT) 486 | { 487 | { 488 | const std::scoped_lock tasks_lock(tasks_mutex); 489 | tasks.emplace(std::forward(task) BS_THREAD_POOL_PRIORITY_OUTPUT); 490 | } 491 | task_available_cv.notify_one(); 492 | } 493 | 494 | /** 495 | * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is only called only once per block, but it is up to the user make sure the block function correctly deals with all the indices in each block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen. 496 | * 497 | * @tparam T The type of the indices. Should be a signed or unsigned integer. 498 | * @tparam F The type of the function to loop through. 499 | * @param first_index The first index in the loop. 500 | * @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted. 501 | * @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`. 502 | * @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool. 503 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 504 | */ 505 | template 506 | void detach_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT) 507 | { 508 | if (index_after_last > first_index) 509 | { 510 | const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count); 511 | for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk) 512 | detach_task( 513 | [block = std::forward(block), start = blks.start(blk), end = blks.end(blk)] 514 | { 515 | block(start, end); 516 | } BS_THREAD_POOL_PRIORITY_OUTPUT); 517 | } 518 | } 519 | 520 | /** 521 | * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the loop finishes executing, otherwise bad things will happen. 522 | * 523 | * @tparam T The type of the indices. 
Should be a signed or unsigned integer. 524 | * @tparam F The type of the function to loop through. 525 | * @param first_index The first index in the loop. 526 | * @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted. 527 | * @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index. 528 | * @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool. 529 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 530 | */ 531 | template 532 | void detach_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT) 533 | { 534 | if (index_after_last > first_index) 535 | { 536 | const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count); 537 | for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk) 538 | detach_task( 539 | [loop = std::forward(loop), start = blks.start(blk), end = blks.end(blk)] 540 | { 541 | for (T i = start; i < end; ++i) 542 | loop(i); 543 | } BS_THREAD_POOL_PRIORITY_OUTPUT); 544 | } 545 | } 546 | 547 | /** 548 | * @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Does not return a `multi_future`, so the user must use `wait()` or some other method to ensure that the sequence finishes executing, otherwise bad things will happen. 549 | * 550 | * @tparam T The type of the indices. Should be a signed or unsigned integer. 551 | * @tparam F The type of the function used to define the sequence. 552 | * @param first_index The first index in the sequence. 553 | * @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted. 554 | * @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index. 555 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 556 | */ 557 | template 558 | void detach_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT) 559 | { 560 | for (T i = first_index; i < index_after_last; ++i) 561 | detach_task( 562 | [sequence = std::forward(sequence), i] 563 | { 564 | sequence(i); 565 | } BS_THREAD_POOL_PRIORITY_OUTPUT); 566 | } 567 | 568 | /** 569 | * @brief Reset the pool with the total number of hardware threads available, as reported by the implementation. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. 
If the pool was paused before resetting it, the new pool will be paused as well. 570 | */ 571 | void reset() 572 | { 573 | reset(0, [] {}); 574 | } 575 | 576 | /** 577 | * @brief Reset the pool with a new number of threads. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well. 578 | * 579 | * @param num_threads The number of threads to use. 580 | */ 581 | void reset(const concurrency_t num_threads) 582 | { 583 | reset(num_threads, [] {}); 584 | } 585 | 586 | /** 587 | * @brief Reset the pool with the total number of hardware threads available, as reported by the implementation, and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well. 588 | * 589 | * @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed. 590 | */ 591 | void reset(const std::function& init_task) 592 | { 593 | reset(0, init_task); 594 | } 595 | 596 | /** 597 | * @brief Reset the pool with a new number of threads and a new initialization function. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads and initialization function. Any tasks that were waiting in the queue before the pool was reset will then be executed by the new threads. If the pool was paused before resetting it, the new pool will be paused as well. 598 | * 599 | * @param num_threads The number of threads to use. 600 | * @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. The function must take no arguments and have no return value. It will only be executed exactly once, when the thread is first constructed. 601 | */ 602 | void reset(const concurrency_t num_threads, const std::function& init_task) 603 | { 604 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE 605 | std::unique_lock tasks_lock(tasks_mutex); 606 | const bool was_paused = paused; 607 | paused = true; 608 | tasks_lock.unlock(); 609 | #endif 610 | wait(); 611 | destroy_threads(); 612 | thread_count = determine_thread_count(num_threads); 613 | threads = std::make_unique(thread_count); 614 | create_threads(init_task); 615 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE 616 | tasks_lock.lock(); 617 | paused = was_paused; 618 | #endif 619 | } 620 | 621 | /** 622 | * @brief Submit a function with no arguments into the task queue, with the specified priority. To submit a function with arguments, enclose it in a lambda expression. If the function has a return value, get a future for the eventual returned value. If the function has no return value, get an `std::future` which can be used to wait until the task finishes. 623 | * 624 | * @tparam F The type of the function. 625 | * @tparam R The return type of the function (can be `void`). 
626 | * @param task The function to submit. 627 | * @param priority The priority of the task. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 628 | * @return A future to be used later to wait for the function to finish executing and/or obtain its returned value if it has one. 629 | */ 630 | template >> 631 | [[nodiscard]] std::future submit_task(F&& task BS_THREAD_POOL_PRIORITY_INPUT) 632 | { 633 | const std::shared_ptr> task_promise = std::make_shared>(); 634 | detach_task( 635 | [task = std::forward(task), task_promise] 636 | { 637 | #ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING 638 | try 639 | { 640 | #endif 641 | if constexpr (std::is_void_v) 642 | { 643 | task(); 644 | task_promise->set_value(); 645 | } 646 | else 647 | { 648 | task_promise->set_value(task()); 649 | } 650 | #ifndef BS_THREAD_POOL_DISABLE_EXCEPTION_HANDLING 651 | } 652 | catch (...) 653 | { 654 | try 655 | { 656 | task_promise->set_exception(std::current_exception()); 657 | } 658 | catch (...) 659 | { 660 | } 661 | } 662 | #endif 663 | } BS_THREAD_POOL_PRIORITY_OUTPUT); 664 | return task_promise->get_future(); 665 | } 666 | 667 | /** 668 | * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The block function takes two arguments, the start and end of the block, so that it is only called only once per block, but it is up to the user make sure the block function correctly deals with all the indices in each block. Returns a `multi_future` that contains the futures for all of the blocks. 669 | * 670 | * @tparam T The type of the indices. Should be a signed or unsigned integer. 671 | * @tparam F The type of the function to loop through. 672 | * @tparam R The return type of the function to loop through (can be `void`). 673 | * @param first_index The first index in the loop. 674 | * @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no blocks will be submitted, and an empty `multi_future` will be returned. 675 | * @param block A function that will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. `block(start, end)` should typically involve a loop of the form `for (T i = start; i < end; ++i)`. 676 | * @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool. 677 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 678 | * @return A `multi_future` that can be used to wait for all the blocks to finish. If the block function returns a value, the `multi_future` can also be used to obtain the values returned by each block. 679 | */ 680 | template , T, T>> 681 | [[nodiscard]] multi_future submit_blocks(const T first_index, const T index_after_last, F&& block, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT) 682 | { 683 | if (index_after_last > first_index) 684 | { 685 | const blocks blks(first_index, index_after_last, num_blocks ? 
num_blocks : thread_count); 686 | multi_future future; 687 | future.reserve(blks.get_num_blocks()); 688 | for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk) 689 | future.push_back(submit_task( 690 | [block = std::forward(block), start = blks.start(blk), end = blks.end(blk)] 691 | { 692 | return block(start, end); 693 | } BS_THREAD_POOL_PRIORITY_OUTPUT)); 694 | return future; 695 | } 696 | return {}; 697 | } 698 | 699 | /** 700 | * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue, with the specified priority. The loop function takes one argument, the loop index, so that it is called many times per block. It must have no return value. Returns a `multi_future` that contains the futures for all of the blocks. 701 | * 702 | * @tparam T The type of the indices. Should be a signed or unsigned integer. 703 | * @tparam F The type of the function to loop through. 704 | * @param first_index The first index in the loop. 705 | * @param index_after_last The index after the last index in the loop. The loop will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned. 706 | * @param loop The function to loop through. Will be called once per index, many times per block. Should take exactly one argument: the loop index. It cannot have a return value. 707 | * @param num_blocks The maximum number of blocks to split the loop into. The default is 0, which means the number of blocks will be equal to the number of threads in the pool. 708 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 709 | * @return A `multi_future` that can be used to wait for all the blocks to finish. 710 | */ 711 | template 712 | [[nodiscard]] multi_future submit_loop(const T first_index, const T index_after_last, F&& loop, const size_t num_blocks = 0 BS_THREAD_POOL_PRIORITY_INPUT) 713 | { 714 | if (index_after_last > first_index) 715 | { 716 | const blocks blks(first_index, index_after_last, num_blocks ? num_blocks : thread_count); 717 | multi_future future; 718 | future.reserve(blks.get_num_blocks()); 719 | for (size_t blk = 0; blk < blks.get_num_blocks(); ++blk) 720 | future.push_back(submit_task( 721 | [loop = std::forward(loop), start = blks.start(blk), end = blks.end(blk)] 722 | { 723 | for (T i = start; i < end; ++i) 724 | loop(i); 725 | } BS_THREAD_POOL_PRIORITY_OUTPUT)); 726 | return future; 727 | } 728 | return {}; 729 | } 730 | 731 | /** 732 | * @brief Submit a sequence of tasks enumerated by indices to the queue, with the specified priority. Returns a `multi_future` that contains the futures for all of the tasks. 733 | * 734 | * @tparam T The type of the indices. Should be a signed or unsigned integer. 735 | * @tparam F The type of the function used to define the sequence. 736 | * @tparam R The return type of the function used to define the sequence (can be `void`). 737 | * @param first_index The first index in the sequence. 738 | * @param index_after_last The index after the last index in the sequence. The sequence will iterate from `first_index` to `(index_after_last - 1)` inclusive. In other words, it will be equivalent to `for (T i = first_index; i < index_after_last; ++i)`. 
Note that if `index_after_last <= first_index`, no tasks will be submitted, and an empty `multi_future` will be returned. 739 | * @param sequence The function used to define the sequence. Will be called once per index. Should take exactly one argument, the index. 740 | * @param priority The priority of the tasks. Should be between -32,768 and 32,767 (a signed 16-bit integer). The default is 0. Only enabled if `BS_THREAD_POOL_ENABLE_PRIORITY` is defined. 741 | * @return A `multi_future` that can be used to wait for all the tasks to finish. If the sequence function returns a value, the `multi_future` can also be used to obtain the values returned by each task. 742 | */ 743 | template , T>> 744 | [[nodiscard]] multi_future submit_sequence(const T first_index, const T index_after_last, F&& sequence BS_THREAD_POOL_PRIORITY_INPUT) 745 | { 746 | if (index_after_last > first_index) 747 | { 748 | multi_future future; 749 | future.reserve(static_cast(index_after_last - first_index)); 750 | for (T i = first_index; i < index_after_last; ++i) 751 | future.push_back(submit_task( 752 | [sequence = std::forward(sequence), i] 753 | { 754 | return sequence(i); 755 | } BS_THREAD_POOL_PRIORITY_OUTPUT)); 756 | return future; 757 | } 758 | return {}; 759 | } 760 | 761 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE 762 | /** 763 | * @brief Unpause the pool. The workers will resume retrieving new tasks out of the queue. Only enabled if `BS_THREAD_POOL_ENABLE_PAUSE` is defined. 764 | */ 765 | void unpause() 766 | { 767 | { 768 | const std::scoped_lock tasks_lock(tasks_mutex); 769 | paused = false; 770 | } 771 | task_available_cv.notify_all(); 772 | } 773 | #endif 774 | 775 | // Macros used internally to enable or disable pausing in the waiting and worker functions. 776 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE 777 | #define BS_THREAD_POOL_PAUSED_OR_EMPTY (paused || tasks.empty()) 778 | #else 779 | #define BS_THREAD_POOL_PAUSED_OR_EMPTY tasks.empty() 780 | #endif 781 | 782 | /** 783 | * @brief Wait for tasks to be completed. Normally, this function waits for all tasks, both those that are currently running in the threads and those that are still waiting in the queue. However, if the pool is paused, this function only waits for the currently running tasks (otherwise it would wait forever). Note: To wait for just one specific task, use `submit_task()` instead, and call the `wait()` member function of the generated future. 784 | * 785 | * @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined. 786 | */ 787 | void wait() 788 | { 789 | #ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 790 | if (this_thread::get_pool() == this) 791 | throw wait_deadlock(); 792 | #endif 793 | std::unique_lock tasks_lock(tasks_mutex); 794 | waiting = true; 795 | tasks_done_cv.wait(tasks_lock, 796 | [this] 797 | { 798 | return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY; 799 | }); 800 | waiting = false; 801 | } 802 | 803 | /** 804 | * @brief Wait for tasks to be completed, but stop waiting after the specified duration has passed. 805 | * 806 | * @tparam R An arithmetic type representing the number of ticks to wait. 807 | * @tparam P An `std::ratio` representing the length of each tick in seconds. 808 | * @param duration The amount of time to wait. 809 | * @return `true` if all tasks finished running, `false` if the duration expired but some tasks are still running. 
810 | * 811 | * @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined. 812 | */ 813 | template 814 | bool wait_for(const std::chrono::duration& duration) 815 | { 816 | #ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 817 | if (this_thread::get_pool() == this) 818 | throw wait_deadlock(); 819 | #endif 820 | std::unique_lock tasks_lock(tasks_mutex); 821 | waiting = true; 822 | const bool status = tasks_done_cv.wait_for(tasks_lock, duration, 823 | [this] 824 | { 825 | return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY; 826 | }); 827 | waiting = false; 828 | return status; 829 | } 830 | 831 | /** 832 | * @brief Wait for tasks to be completed, but stop waiting after the specified time point has been reached. 833 | * 834 | * @tparam C The type of the clock used to measure time. 835 | * @tparam D An `std::chrono::duration` type used to indicate the time point. 836 | * @param timeout_time The time point at which to stop waiting. 837 | * @return `true` if all tasks finished running, `false` if the time point was reached but some tasks are still running. 838 | * 839 | * @throws `wait_deadlock` if called from within a thread of the same pool, which would result in a deadlock. Only enabled if `BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK` is defined. 840 | */ 841 | template 842 | bool wait_until(const std::chrono::time_point& timeout_time) 843 | { 844 | #ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 845 | if (this_thread::get_pool() == this) 846 | throw wait_deadlock(); 847 | #endif 848 | std::unique_lock tasks_lock(tasks_mutex); 849 | waiting = true; 850 | const bool status = tasks_done_cv.wait_until(tasks_lock, timeout_time, 851 | [this] 852 | { 853 | return (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY; 854 | }); 855 | waiting = false; 856 | return status; 857 | } 858 | 859 | #ifdef BS_THREAD_POOL_ENABLE_WAIT_DEADLOCK_CHECK 860 | // ============== 861 | // Public classes 862 | // ============== 863 | 864 | /** 865 | * @brief An exception that will be thrown by `wait()`, `wait_for()`, and `wait_until()` if the user tries to call them from within a thread of the same pool, which would result in a deadlock. 866 | */ 867 | struct wait_deadlock : public std::runtime_error 868 | { 869 | wait_deadlock() : std::runtime_error("BS::thread_pool::wait_deadlock"){}; 870 | }; 871 | #endif 872 | 873 | private: 874 | // ======================== 875 | // Private member functions 876 | // ======================== 877 | 878 | /** 879 | * @brief Create the threads in the pool and assign a worker to each thread. 880 | * 881 | * @param init_task An initialization function to run in each thread before it starts to execute any submitted tasks. 882 | */ 883 | void create_threads(const std::function& init_task) 884 | { 885 | { 886 | const std::scoped_lock tasks_lock(tasks_mutex); 887 | tasks_running = thread_count; 888 | workers_running = true; 889 | } 890 | for (concurrency_t i = 0; i < thread_count; ++i) 891 | { 892 | threads[i] = std::thread(&thread_pool::worker, this, i, init_task); 893 | } 894 | } 895 | 896 | /** 897 | * @brief Destroy the threads in the pool. 
898 |      */
899 |     void destroy_threads()
900 |     {
901 |         {
902 |             const std::scoped_lock tasks_lock(tasks_mutex);
903 |             workers_running = false;
904 |         }
905 |         task_available_cv.notify_all();
906 |         for (concurrency_t i = 0; i < thread_count; ++i)
907 |         {
908 |             threads[i].join();
909 |         }
910 |     }
911 | 
912 |     /**
913 |      * @brief Determine how many threads the pool should have, based on the parameter passed to the constructor or reset().
914 |      *
915 |      * @param num_threads The parameter passed to the constructor or `reset()`. If the parameter is a positive number, then the pool will be created with this number of threads. If the parameter is non-positive, or a parameter was not supplied (in which case it will have the default value of 0), then the pool will be created with the total number of hardware threads available, as obtained from `std::thread::hardware_concurrency()`. If the latter returns zero for some reason, then the pool will be created with just one thread.
916 |      * @return The number of threads to use for constructing the pool.
917 |      */
918 |     [[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t num_threads)
919 |     {
920 |         if (num_threads > 0)
921 |             return num_threads;
922 |         if (std::thread::hardware_concurrency() > 0)
923 |             return std::thread::hardware_concurrency();
924 |         return 1;
925 |     }
926 | 
927 |     /**
928 |      * @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by `detach_task()` that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies `wait()` in case it is waiting.
929 |      *
930 |      * @param idx The index of this thread.
931 |      * @param init_task An initialization function to run in this thread before it starts to execute any submitted tasks.
932 |      */
933 |     void worker(const concurrency_t idx, const std::function<void()>& init_task)
934 |     {
935 |         this_thread::get_index.index = idx;
936 |         this_thread::get_pool.pool = this;
937 |         init_task();
938 |         std::unique_lock tasks_lock(tasks_mutex);
939 |         while (true)
940 |         {
941 |             --tasks_running;
942 |             tasks_lock.unlock();
943 |             if (waiting && (tasks_running == 0) && BS_THREAD_POOL_PAUSED_OR_EMPTY)
944 |                 tasks_done_cv.notify_all();
945 |             tasks_lock.lock();
946 |             task_available_cv.wait(tasks_lock,
947 |                 [this]
948 |                 {
949 |                     return !BS_THREAD_POOL_PAUSED_OR_EMPTY || !workers_running;
950 |                 });
951 |             if (!workers_running)
952 |                 break;
953 |             {
954 | #ifdef BS_THREAD_POOL_ENABLE_PRIORITY
955 |                 const std::function<void()> task = std::move(std::remove_const_t<pr_task&>(tasks.top()).task);
956 |                 tasks.pop();
957 | #else
958 |                 const std::function<void()> task = std::move(tasks.front());
959 |                 tasks.pop();
960 | #endif
961 |                 ++tasks_running;
962 |                 tasks_lock.unlock();
963 |                 task();
964 |             }
965 |             tasks_lock.lock();
966 |         }
967 |         this_thread::get_index.index = std::nullopt;
968 |         this_thread::get_pool.pool = std::nullopt;
969 |     }
970 | 
971 |     // ===============
972 |     // Private classes
973 |     // ===============
974 | 
975 |     /**
976 |      * @brief A helper class to divide a range into blocks. Used by `detach_blocks()`, `submit_blocks()`, `detach_loop()`, and `submit_loop()`.
977 |      *
978 |      * @tparam T The type of the indices. Should be a signed or unsigned integer.
979 |      */
980 |     template <typename T>
981 |     class [[nodiscard]] blocks
982 |     {
983 |     public:
984 |         /**
985 |          * @brief Construct a `blocks` object with the given specifications.
986 |          *
987 |          * @param first_index_ The first index in the range.
988 |          * @param index_after_last_ The index after the last index in the range.
989 |          * @param num_blocks_ The desired number of blocks to divide the range into.
990 |          */
991 |         blocks(const T first_index_, const T index_after_last_, const size_t num_blocks_) : first_index(first_index_), index_after_last(index_after_last_), num_blocks(num_blocks_)
992 |         {
993 |             if (index_after_last > first_index)
994 |             {
995 |                 const size_t total_size = static_cast<size_t>(index_after_last - first_index);
996 |                 if (num_blocks > total_size)
997 |                     num_blocks = total_size;
998 |                 block_size = total_size / num_blocks;
999 |                 remainder = total_size % num_blocks;
1000 |                 if (block_size == 0)
1001 |                 {
1002 |                     block_size = 1;
1003 |                     num_blocks = (total_size > 1) ? total_size : 1;
1004 |                 }
1005 |             }
1006 |             else
1007 |             {
1008 |                 num_blocks = 0;
1009 |             }
1010 |         }
1011 | 
1012 |         /**
1013 |          * @brief Get the first index of a block.
1014 |          *
1015 |          * @param block The block number.
1016 |          * @return The first index.
1017 |          */
1018 |         [[nodiscard]] T start(const size_t block) const
1019 |         {
1020 |             return first_index + static_cast<T>(block * block_size) + static_cast<T>(block < remainder ? block : remainder);
1021 |         }
1022 | 
1023 |         /**
1024 |          * @brief Get the index after the last index of a block.
1025 |          *
1026 |          * @param block The block number.
1027 |          * @return The index after the last index.
1028 |          */
1029 |         [[nodiscard]] T end(const size_t block) const
1030 |         {
1031 |             return (block == num_blocks - 1) ? index_after_last : start(block + 1);
1032 |         }
1033 | 
1034 |         /**
1035 |          * @brief Get the number of blocks. Note that this may be different than the desired number of blocks that was passed to the constructor.
1036 |          *
1037 |          * @return The number of blocks.
1038 |          */
1039 |         [[nodiscard]] size_t get_num_blocks() const
1040 |         {
1041 |             return num_blocks;
1042 |         }
1043 | 
1044 |     private:
1045 |         /**
1046 |          * @brief The size of each block (except possibly the last block).
1047 |          */
1048 |         size_t block_size = 0;
1049 | 
1050 |         /**
1051 |          * @brief The first index in the range.
1052 |          */
1053 |         T first_index = 0;
1054 | 
1055 |         /**
1056 |          * @brief The index after the last index in the range.
1057 |          */
1058 |         T index_after_last = 0;
1059 | 
1060 |         /**
1061 |          * @brief The number of blocks.
1062 |          */
1063 |         size_t num_blocks = 0;
1064 | 
1065 |         /**
1066 |          * @brief The remainder obtained after dividing the total size by the number of blocks.
1067 |          */
1068 |         size_t remainder = 0;
1069 |     }; // class blocks
1070 | 
1071 | #ifdef BS_THREAD_POOL_ENABLE_PRIORITY
1072 |     /**
1073 |      * @brief A helper class to store a task with an assigned priority.
1074 |      */
1075 |     class [[nodiscard]] pr_task
1076 |     {
1077 |         friend class thread_pool;
1078 | 
1079 |     public:
1080 |         /**
1081 |          * @brief Construct a new task with an assigned priority by copying the task.
1082 |          *
1083 |          * @param task_ The task.
1084 |          * @param priority_ The desired priority.
1085 |          */
1086 |         explicit pr_task(const std::function<void()>& task_, const priority_t priority_ = 0) : task(task_), priority(priority_) {}
1087 | 
1088 |         /**
1089 |          * @brief Construct a new task with an assigned priority by moving the task.
1090 |          *
1091 |          * @param task_ The task.
1092 |          * @param priority_ The desired priority.
1093 |          */
1094 |         explicit pr_task(std::function<void()>&& task_, const priority_t priority_ = 0) : task(std::move(task_)), priority(priority_) {}
1095 | 
1096 |         /**
1097 |          * @brief Compare the priority of two tasks.
1098 |          *
1099 |          * @param lhs The first task.
1100 |          * @param rhs The second task.
1101 |          * @return `true` if the first task has a lower priority than the second task, `false` otherwise.
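         *
         * (Clarifying note, added: `std::priority_queue` pops its largest element first, so with this comparison the task with the numerically highest priority value is retrieved first; equal priorities are popped in an unspecified order, not FIFO.)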
1102 |          */
1103 |         [[nodiscard]] friend bool operator<(const pr_task& lhs, const pr_task& rhs)
1104 |         {
1105 |             return lhs.priority < rhs.priority;
1106 |         }
1107 | 
1108 |     private:
1109 |         /**
1110 |          * @brief The task.
1111 |          */
1112 |         std::function<void()> task = {};
1113 | 
1114 |         /**
1115 |          * @brief The priority of the task.
1116 |          */
1117 |         priority_t priority = 0;
1118 |     }; // class pr_task
1119 | #endif
1120 | 
1121 |     // ============
1122 |     // Private data
1123 |     // ============
1124 | 
1125 | #ifdef BS_THREAD_POOL_ENABLE_PAUSE
1126 |     /**
1127 |      * @brief A flag indicating whether the workers should pause. When set to `true`, the workers temporarily stop retrieving new tasks out of the queue, although any tasks already executed will keep running until they are finished. When set to `false` again, the workers resume retrieving tasks.
1128 |      */
1129 |     bool paused = false;
1130 | #endif
1131 | 
1132 |     /**
1133 |      * @brief A condition variable to notify `worker()` that a new task has become available.
1134 |      */
1135 |     std::condition_variable task_available_cv = {};
1136 | 
1137 |     /**
1138 |      * @brief A condition variable to notify `wait()` that the tasks are done.
1139 |      */
1140 |     std::condition_variable tasks_done_cv = {};
1141 | 
1142 |     /**
1143 |      * @brief A queue of tasks to be executed by the threads.
1144 |      */
1145 | #ifdef BS_THREAD_POOL_ENABLE_PRIORITY
1146 |     std::priority_queue<pr_task> tasks = {};
1147 | #else
1148 |     std::queue<std::function<void()>> tasks = {};
1149 | #endif
1150 | 
1151 |     /**
1152 |      * @brief A counter for the total number of currently running tasks.
1153 |      */
1154 |     size_t tasks_running = 0;
1155 | 
1156 |     /**
1157 |      * @brief A mutex to synchronize access to the task queue by different threads.
1158 |      */
1159 |     mutable std::mutex tasks_mutex = {};
1160 | 
1161 |     /**
1162 |      * @brief The number of threads in the pool.
1163 |      */
1164 |     concurrency_t thread_count = 0;
1165 | 
1166 |     /**
1167 |      * @brief A smart pointer to manage the memory allocated for the threads.
1168 |      */
1169 |     std::unique_ptr<std::thread[]> threads = nullptr;
1170 | 
1171 |     /**
1172 |      * @brief A flag indicating that `wait()` is active and expects to be notified whenever a task is done.
1173 |      */
1174 |     bool waiting = false;
1175 | 
1176 |     /**
1177 |      * @brief A flag indicating to the workers to keep running. When set to `false`, the workers terminate permanently.
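     *
     * (Clarifying note, added: this flag is written by `create_threads()` and `destroy_threads()` while holding `tasks_mutex` and is checked by each worker inside its condition-variable predicate, so a plain `bool` suffices here.)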
1178 |      */
1179 |     bool workers_running = false;
1180 | }; // class thread_pool
1181 | } // namespace BS
1182 | #endif
--------------------------------------------------------------------------------
/include/uring.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <liburing.h>
4 | #include "asyncio.h"
5 | 
6 | class UringAsyncIO : public AsyncIO
7 | {
8 | private:
9 |     unsigned int n_write_events, n_read_events;
10 |     unsigned int n_entries;
11 |     io_uring ring;
12 | 
13 |     void get_event(WaitType wt);
14 | 
15 | public:
16 |     UringAsyncIO(unsigned int n_entries, unsigned int n_tasks);
17 |     ~UringAsyncIO();
18 | 
19 |     void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
20 |     void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
21 |     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
22 |     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
23 | 
24 |     void register_h2d(unsigned int num_tensors);
25 |     void sync_h2d();
26 |     void sync_write_events();
27 |     void sync_read_events();
28 |     void synchronize();
29 | 
30 |     void register_file(int fd);
31 |     void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
32 |     void register_tasks(unsigned int num_tasks);
33 | };
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | click
3 | torch
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | from platform import uname
5 | from subprocess import call
6 | from typing import List
7 | 
8 | from packaging import version
9 | from setuptools import find_packages, setup
10 | 
11 | TENSORNVME_INITIALIZE_RE_BLOCK = (
12 |     r"^# >>> tensornvme initialize >>>(?:\n|\r\n)" r"([\s\S]*?)" r"# <<< tensornvme initialize <<<(?:\n|\r\n)?"
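    # Clarifying comment, added: this pattern matches the managed block that
    # setup_bashrc() below appends to ~/.bashrc, so repeated installs do not
    # duplicate the LD_LIBRARY_PATH export.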
13 | )
14 | 
15 | 
16 | def check_uring_compatibility():
17 |     uname_info = uname()
18 |     if uname_info.system != "Linux":
19 |         raise RuntimeError("Only Linux is supported")
20 |     kernel_version = version.parse(uname_info.release.split("-")[0])
21 |     return kernel_version >= version.parse("5.10")
22 | 
23 | 
24 | def check_pthread_compatibility():
25 |     uname_info = uname()
26 |     if uname_info.system != "Linux":
27 |         raise RuntimeError("Only Linux is supported")
28 |     return True
29 | 
30 | 
31 | this_dir = os.path.dirname(os.path.abspath(__file__))
32 | backend_install_dir = os.path.join(os.path.expanduser("~"), ".tensornvme")
33 | 
34 | debug = os.environ.get("DEBUG", "0") == "1"
35 | enable_uring = True
36 | enable_aio = True
37 | enable_pthread = True
38 | if os.environ.get("DISABLE_URING") == "1" or not check_uring_compatibility():
39 |     enable_uring = False
40 | if os.environ.get("DISABLE_AIO") == "1":
41 |     enable_aio = False
42 | if os.environ.get("DISABLE_PTHREAD") == "1" or not check_pthread_compatibility():
43 |     enable_pthread = False
44 | 
45 | assert enable_aio or enable_uring or enable_pthread
46 | if os.environ.get("WITH_ROOT") == "1":
47 |     backend_install_dir = "/usr"
48 |     if not os.access(backend_install_dir, os.W_OK):
49 |         raise RuntimeError("Permission denied, please make sure you have root access")
50 | 
51 | libraries = ["aio"]
52 | sources = [
53 |     "csrc/offload.cpp",
54 |     "csrc/uring.cpp",
55 |     "csrc/aio.cpp",
56 |     "csrc/space_mgr.cpp",
57 |     "csrc/backend.cpp",
58 |     "csrc/async_file_io.cpp",
59 |     "csrc/py_api.cpp",
60 |     "csrc/pthread_backend.cpp",
61 | ]
62 | extra_objects = []
63 | define_macros = []
64 | ext_modules = []
65 | cmdclass = {}
66 | 
67 | 
68 | def cpp_ext_helper(name, sources, **kwargs):
69 |     from torch.utils.cpp_extension import CUDA_HOME, CppExtension
70 | 
71 |     extra_include_dirs = []
72 | 
73 |     if CUDA_HOME is not None:
74 |         extra_include_dirs.append(os.path.join(CUDA_HOME, "include"))
75 | 
76 |     if "C_INCLUDE_PATH" in os.environ:
77 |         extra_include_dirs.extend(os.environ["C_INCLUDE_PATH"].split(":"))
78 |     if "CPLUS_INCLUDE_PATH" in os.environ:
79 |         extra_include_dirs.extend(os.environ["CPLUS_INCLUDE_PATH"].split(":"))
80 |     extra_include_dirs = list(filter(lambda s: len(s) > 0, set(extra_include_dirs)))
81 |     return CppExtension(
82 |         name,
83 |         [os.path.join(this_dir, path) for path in sources],
84 |         include_dirs=[
85 |             os.path.join(this_dir, "csrc"),
86 |             os.path.join(this_dir, "include"),
87 |             os.path.join(backend_install_dir, "include"),
88 |             *extra_include_dirs,
89 |         ],
90 |         library_dirs=[os.path.join(backend_install_dir, "lib")],
91 |         **kwargs,
92 |     )
93 | 
94 | 
95 | def find_static_lib(lib_name: str, lib_paths: List[str] = []) -> str:
96 |     static_lib_name = f"lib{lib_name}.a"
97 |     lib_paths.extend(["/usr/lib", "/usr/lib64", "/usr/lib/x86_64-linux-gnu/"])
98 |     if os.environ.get("LIBRARY_PATH", None) is not None:
99 |         lib_paths.extend(os.environ["LIBRARY_PATH"].split(":"))
100 |     for lib_dir in lib_paths:
101 |         if os.path.isdir(lib_dir):
102 |             for filename in os.listdir(lib_dir):
103 |                 if filename == static_lib_name:
104 |                     return os.path.join(lib_dir, static_lib_name)
105 |     raise RuntimeError(f"{static_lib_name} is not found in {lib_paths}")
106 | 
107 | 
108 | def setup_bashrc():
109 |     init_rules = f'export LD_LIBRARY_PATH="{backend_install_dir}/lib:$LD_LIBRARY_PATH"'
110 |     bashrc_path = os.path.join(os.path.expanduser("~"), ".bashrc")
111 |     with open(bashrc_path, "a+") as f:
112 |         f.seek(0)
113 |         rules = f.read()
114 |         if not re.search(TENSORNVME_INITIALIZE_RE_BLOCK, 
rules, flags=re.DOTALL | re.MULTILINE):
115 |             f.write(f"# >>> tensornvme initialize >>>\n{init_rules}\n# <<< tensornvme initialize <<<\n")
116 |     print(f"{bashrc_path} is changed, please source it.")
117 | 
118 | 
119 | def setup_dependencies():
120 |     build_dir = os.path.join(this_dir, "cmake-build")
121 |     if debug:
122 |         define_macros.append(("DEBUG", None))
123 |     if not enable_uring:
124 |         define_macros.append(("DISABLE_URING", None))
125 |         sources.remove("csrc/uring.cpp")
126 |     if not enable_aio:
127 |         define_macros.append(("DISABLE_AIO", None))
128 |         sources.remove("csrc/aio.cpp")
129 |         libraries.remove("aio")
130 |     os.makedirs(build_dir, exist_ok=True)
131 |     os.makedirs(backend_install_dir, exist_ok=True)
132 |     os.chdir(build_dir)
133 |     call(["cmake", "..", f"-DBACKEND_INSTALL_PREFIX={backend_install_dir}"])
134 |     if enable_uring:
135 |         call(["make", "extern_uring"])
136 |         extra_objects.append(find_static_lib("uring", [os.path.join(backend_install_dir, "lib")]))
137 |     if enable_aio:
138 |         call(["make", "extern_aio"])
139 |     os.chdir(this_dir)
140 |     if os.environ.get("WITH_ROOT") != "1":
141 |         setup_bashrc()
142 | 
143 | 
144 | if sys.argv[1] in ("install", "develop", "bdist_wheel"):
145 |     setup_dependencies()
146 |     from torch.utils.cpp_extension import BuildExtension
147 | 
148 |     ext_modules.append(
149 |         cpp_ext_helper(
150 |             "tensornvme._C", sources, extra_objects=extra_objects, libraries=libraries, define_macros=define_macros
151 |         )
152 |     )
153 |     cmdclass["build_ext"] = BuildExtension
154 | 
155 | 
156 | def get_version():
157 |     with open("version.txt") as f:
158 |         version = f.read().strip()
159 |     return version
160 | 
161 | 
162 | def fetch_requirements(path):
163 |     with open(path, "r") as fd:
164 |         return [r.strip() for r in fd.readlines()]
165 | 
166 | 
167 | def fetch_readme():
168 |     with open("README.md", encoding="utf-8") as f:
169 |         return f.read()
170 | 
171 | 
172 | setup(
173 |     name="tensornvme",
174 |     version=get_version(),
175 |     packages=find_packages(exclude=("3rd", "csrc", "tests", "include", "*.egg-info")),
176 |     ext_modules=ext_modules,
177 |     cmdclass=cmdclass,
178 |     entry_points={"console_scripts": ["tensornvme=tensornvme.cli:cli"]},
179 |     description="A tensor disk offloader without data copying.",
180 |     long_description=fetch_readme(),
181 |     long_description_content_type="text/markdown",
182 |     license="Apache Software License 2.0",
183 |     install_requires=fetch_requirements("requirements.txt"),
184 |     python_requires=">=3.6",
185 |     classifiers=[
186 |         "Programming Language :: Python :: 3",
187 |         "License :: OSI Approved :: Apache Software License",
188 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
189 |     ],
190 | )
--------------------------------------------------------------------------------
/tensornvme/_C/__init__.pyi:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Set
2 | 
3 | from torch import Tensor
4 | 
5 | class Offloader:
6 |     def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ...
7 |     def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
8 |     def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
9 |     def sync_write(self, tensor: Tensor, key: str) -> None: ...
10 |     def sync_read(self, tensor: Tensor, key: str) -> None: ...
11 |     def sync_write_events(self) -> None: ...
12 |     def sync_read_events(self) -> None: ...
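    # Usage sketch (illustrative comment, not part of the generated stub; the
    # path and key below are made up):
    #   off = Offloader("/tmp/offload.bin", n_entries=8, backend="aio")
    #   off.sync_write(t, "k0")   # blocks until t's bytes are on disk under key "k0"
    #   off.sync_read(t, "k0")    # blocks until t is filled back from disk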
13 | def synchronize(self) -> None: ... 14 | def async_writev(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ... 15 | def async_readv(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ... 16 | def sync_writev(self, tensors: List[Tensor], key: str) -> None: ... 17 | def sync_readv(self, tensors: List[Tensor], key: str) -> None: ... 18 | 19 | def get_backends() -> Set[str]: ... 20 | def probe_backend(backend: str) -> bool: ... 21 | 22 | class AsyncFileWriter: 23 | def __init__(self, fd: int, n_entries: int, backend: str = "aio", n_tasks: int = 0) -> None: ... 24 | def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ... 25 | def write_tensor( 26 | self, 27 | tensor: Tensor, 28 | offset: int, 29 | callback: Optional[Callable[[], None]] = None, 30 | pinned: Optional[Tensor] = None, 31 | ) -> None: ... 32 | def synchronize(self) -> None: ... 33 | def register_tasks(self, num_tasks: int) -> None: ... 34 | -------------------------------------------------------------------------------- /tensornvme/__init__.py: -------------------------------------------------------------------------------- 1 | from .offload import DiskOffloader 2 | 3 | __all__ = [ 4 | 'DiskOffloader' 5 | ] 6 | -------------------------------------------------------------------------------- /tensornvme/async_file_io.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | from functools import partial 4 | from typing import List, Optional 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from tensornvme._C import AsyncFileWriter as AsyncFileWriterC 10 | 11 | 12 | class AsyncFileWriter: 13 | def __init__(self, path: str, n_entries: int = 16, backend=None, n_tasks: int = 0) -> None: 14 | # this still takes ram buffer, which may lead to OOM 15 | # self.f = open(path, "wb", buffering=0) 16 | self.fd = os.open(path, os.O_WRONLY | os.O_CREAT, mode=0o664) 17 | if backend is not None: 18 | self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend, n_tasks=n_tasks) 19 | else: 20 | self.io = AsyncFileWriterC(self.fd, n_entries, n_tasks=n_tasks) 21 | self.offset = 0 22 | # must ensure the data is not garbage collected 23 | self.buffers = [] 24 | self.comm_stream = torch.cuda.Stream() 25 | 26 | def write(self, data: bytes) -> int: 27 | ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char)) 28 | addr = ctypes.addressof(ptr.contents) 29 | self.buffers.append(data) # append before callback is called 30 | self.io.write( 31 | addr, len(data), self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1) 32 | ) 33 | self.offset += len(data) 34 | 35 | return len(data) 36 | 37 | def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> None: 38 | self.buffers.append(py_ref) # append before callback is called 39 | self.io.write( 40 | buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1) 41 | ) 42 | self.offset += n_bytes 43 | 44 | def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None: 45 | with torch.cuda.stream(self.comm_stream): 46 | self.buffers.append(tensor) # append before callback is called 47 | self.io.write_tensor( 48 | tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned 49 | ) 50 | self.offset += tensor.numel() * tensor.element_size() 51 | 52 | def 
register_h2d(self, num_tensors: int) -> None: 53 | self.io.register_h2d(num_tensors) 54 | 55 | def sync_before_step(self): 56 | self.io.sync_h2d() 57 | 58 | @staticmethod 59 | def gc_callback(listt: List, idx: int) -> None: 60 | listt[idx] = None 61 | 62 | def flush(self) -> None: 63 | pass 64 | 65 | def synchronize(self) -> None: 66 | self.io.synchronize() 67 | self.buffers.clear() 68 | 69 | def __del__(self) -> None: 70 | self.synchronize() 71 | os.close(self.fd) 72 | 73 | def register_tasks(self, num_tasks: int) -> None: 74 | self.io.register_tasks(num_tasks) 75 | -------------------------------------------------------------------------------- /tensornvme/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from .cli import cli 2 | 3 | __all__ = ['cli'] 4 | -------------------------------------------------------------------------------- /tensornvme/cli/check.py: -------------------------------------------------------------------------------- 1 | import click 2 | from tensornvme._C import probe_backend, get_backends 3 | 4 | 5 | def check_backend(backend: str): 6 | if backend not in get_backends(): 7 | click.echo(f'Invalid backend: {backend}') 8 | return 9 | status = 'x' 10 | if probe_backend(backend): 11 | status = u'\u2713' 12 | click.echo(f'{backend}: {status}') 13 | 14 | 15 | @click.command(help='Check if backends are available') 16 | @click.option('--backend', type=click.Choice(['all', *get_backends()]), default='all') 17 | def check(backend: str): 18 | click.echo('Check backends:') 19 | if backend == 'all': 20 | for backend in get_backends(): 21 | check_backend(backend) 22 | else: 23 | check_backend(backend) 24 | -------------------------------------------------------------------------------- /tensornvme/cli/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | from .check import check 3 | 4 | 5 | @click.group() 6 | def cli(): 7 | pass 8 | 9 | 10 | cli.add_command(check) 11 | 12 | if __name__ == '__main__': 13 | cli() 14 | -------------------------------------------------------------------------------- /tensornvme/offload.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import uuid 4 | from typing import Callable, Optional, List 5 | from tensornvme._C import Offloader, get_backends 6 | 7 | 8 | class DiskOffloader(Offloader): 9 | def __init__(self, dir_name: str, n_entries: int = 16, backend: str = 'uring') -> None: 10 | if not os.path.exists(dir_name): 11 | os.mkdir(dir_name) 12 | assert os.path.isdir(dir_name) 13 | filename = os.path.join(dir_name, f'offload-{uuid.uuid4().hex}') 14 | while os.path.exists(filename): 15 | filename = os.path.join(dir_name, f'offload-{uuid.uuid4().hex}') 16 | super().__init__(filename, n_entries, backend) 17 | 18 | def async_write(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None: 19 | assert tensor.storage().size() > 0 20 | 21 | def callback_fn(): 22 | tensor.storage().resize_(0) 23 | if callback is not None: 24 | callback() 25 | super().async_write(tensor, str(id(tensor)), callback_fn) 26 | 27 | def async_read(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None: 28 | if tensor.storage().size() == 0: 29 | tensor.storage().resize_(tensor.numel()) 30 | super().async_read(tensor, str(id(tensor)), callback) 31 | 32 | def sync_write(self, tensor: torch.Tensor) -> None: 33 | assert tensor.storage().size() > 0 34 | 
super().sync_write(tensor, str(id(tensor))) 35 | tensor.storage().resize_(0) 36 | 37 | def sync_read(self, tensor: torch.Tensor) -> None: 38 | if tensor.storage().size() == 0: 39 | tensor.storage().resize_(tensor.numel()) 40 | super().sync_read(tensor, str(id(tensor))) 41 | 42 | def async_writev(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None: 43 | for tensor in tensors: 44 | assert tensor.storage().size() > 0 45 | key = str(hash(tuple(tensors))) 46 | 47 | def callback_fn(): 48 | for tensor in tensors: 49 | tensor.storage().resize_(0) 50 | if callback is not None: 51 | callback() 52 | super().async_writev(tensors, key, callback_fn) 53 | 54 | def async_readv(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None: 55 | for tensor in tensors: 56 | if tensor.storage().size() == 0: 57 | tensor.storage().resize_(tensor.numel()) 58 | key = str(hash(tuple(tensors))) 59 | super().async_readv(tensors, key, callback) 60 | 61 | def sync_writev(self, tensors: List[torch.Tensor]) -> None: 62 | for tensor in tensors: 63 | assert tensor.storage().size() > 0 64 | key = str(hash(tuple(tensors))) 65 | super().sync_writev(tensors, key) 66 | for tensor in tensors: 67 | tensor.storage().resize_(0) 68 | 69 | def sync_readv(self, tensors: List[torch.Tensor]) -> None: 70 | for tensor in tensors: 71 | if tensor.storage().size() == 0: 72 | tensor.storage().resize_(tensor.numel()) 73 | key = str(hash(tuple(tensors))) 74 | super().sync_readv(tensors, key) 75 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(test_asyncio 2 | test_asyncio.cpp) 3 | target_link_libraries(test_asyncio colo_asyncio) 4 | target_include_directories(test_asyncio INTERFACE .) 5 | add_test(NAME test_asyncio COMMAND test_asyncio) 6 | 7 | 8 | add_executable(test_space_mgr 9 | test_space_mgr.cpp) 10 | target_link_libraries(test_space_mgr space_mgr) 11 | target_include_directories(test_space_mgr INTERFACE .) 
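# Note (added): add_test below registers each binary with CTest, so after building
# the suite can also be run with `ctest` from the build directory (assuming the
# top-level CMakeLists enables testing).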
12 | add_test(NAME test_space_mgr COMMAND test_space_mgr) 13 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | transformers -------------------------------------------------------------------------------- /tests/test_adam.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import nn 7 | from transformers import GPT2Config, GPT2LMHeadModel 8 | 9 | from tensornvme import DiskOffloader 10 | 11 | 12 | class GPTLMModel(nn.Module): 13 | def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, 14 | vocab_size=50257, checkpoint=False): 15 | super().__init__() 16 | self.checkpoint = checkpoint 17 | self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, n_head=num_attention_heads, 18 | n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) 19 | if checkpoint: 20 | self.model.gradient_checkpointing_enable() 21 | 22 | def forward(self, input_ids, attention_mask): 23 | # Only return lm_logits 24 | return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] 25 | 26 | 27 | def gpt2_toy(): 28 | return GPTLMModel(hidden_size=8, num_layers=2, num_attention_heads=2, checkpoint=False) 29 | 30 | 31 | def adam(step, lr, param, grad, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, eps=1e-12): 32 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 33 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 34 | 35 | bias_correction1 = 1 - beta1 ** step 36 | bias_correction2 = 1 - beta2 ** step 37 | step_size = lr / bias_correction1 38 | bias_correction2_sqrt = math.sqrt(bias_correction2) 39 | denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 40 | param.addcdiv_(exp_avg, denom, value=-step_size) 41 | 42 | 43 | class NVMEAdam(torch.optim.Optimizer): 44 | def __init__(self, params, lr, betas=(0.9, 0.999), 45 | offloader: Optional[DiskOffloader] = None, prefetch: int = 0, vecio: bool = False) -> None: 46 | default = dict(lr=lr, betas=betas) 47 | super().__init__(params, default) 48 | self.offloader = offloader 49 | self.prefetch = prefetch 50 | self.vecio = vecio 51 | self.param_to_group = {} 52 | # init states 53 | for group in self.param_groups: 54 | for p in group['params']: 55 | if p.requires_grad: 56 | self.param_to_group[p] = group 57 | state = self.state[p] 58 | state['step'] = 0 59 | state['exp_avg'] = torch.zeros_like(p) 60 | state['exp_avg_sq'] = torch.zeros_like(p) 61 | if self.offloader is None: 62 | continue 63 | if vecio: 64 | self.offloader.sync_writev( 65 | [state['exp_avg'], state['exp_avg_sq']]) 66 | else: 67 | self.offloader.sync_write(state['exp_avg']) 68 | self.offloader.sync_write(state['exp_avg_sq']) 69 | 70 | def step(self, closure=None): 71 | loss = None 72 | if closure is not None: 73 | with torch.enable_grad(): 74 | loss = closure() 75 | 76 | params = [ 77 | p for group in self.param_groups for p in group['params'] if p.grad is not None] 78 | if self.offloader is not None and self.prefetch > 0: 79 | for p in params[:self.prefetch]: 80 | state = self.state[p] 81 | if self.vecio: 82 | self.offloader.sync_readv( 83 | [state['exp_avg'], state['exp_avg_sq']]) 84 | else: 85 | self.offloader.sync_read(state['exp_avg']) 86 | self.offloader.sync_read(state['exp_avg_sq']) 87 | 88 | 
for i, p in enumerate(params): 89 | state = self.state[p] 90 | group = self.param_to_group[p] 91 | state['step'] += 1 92 | beta1, beta2 = group['betas'] 93 | self._pre_step(i, params) 94 | adam(state['step'], group['lr'], p, p.grad, state['exp_avg'], 95 | state['exp_avg_sq'], beta1=beta1, beta2=beta2) 96 | self._post_step(i, params) 97 | 98 | return loss 99 | 100 | def _pre_step(self, idx, params): 101 | if self.offloader is None: 102 | return 103 | if self.prefetch > 0: 104 | if idx % self.prefetch == 0: 105 | self.offloader.sync_read_events() 106 | if idx + self.prefetch < len(params): 107 | for prefetch_p in params[idx + self.prefetch:idx + self.prefetch * 2]: 108 | prefetch_state = self.state[prefetch_p] 109 | if self.vecio: 110 | self.offloader.async_readv( 111 | [prefetch_state['exp_avg'], prefetch_state['exp_avg_sq']]) 112 | else: 113 | self.offloader.async_read( 114 | prefetch_state['exp_avg']) 115 | self.offloader.async_read( 116 | prefetch_state['exp_avg_sq']) 117 | else: 118 | state = self.state[params[idx]] 119 | if self.vecio: 120 | self.offloader.sync_readv( 121 | [state['exp_avg'], state['exp_avg_sq']]) 122 | else: 123 | self.offloader.sync_read(state['exp_avg']) 124 | self.offloader.sync_read(state['exp_avg_sq']) 125 | 126 | def _post_step(self, idx, params): 127 | if self.offloader is None: 128 | return 129 | state = self.state[params[idx]] 130 | if self.prefetch > 0: 131 | if idx % self.prefetch == 0: 132 | self.offloader.sync_write_events() 133 | if self.vecio: 134 | self.offloader.async_writev( 135 | [state['exp_avg'], state['exp_avg_sq']]) 136 | else: 137 | self.offloader.async_write(state['exp_avg']) 138 | self.offloader.async_write(state['exp_avg_sq']) 139 | else: 140 | if self.vecio: 141 | self.offloader.sync_writev( 142 | [state['exp_avg'], state['exp_avg_sq']]) 143 | else: 144 | self.offloader.sync_write(state['exp_avg']) 145 | self.offloader.sync_write(state['exp_avg_sq']) 146 | 147 | 148 | class Adam(torch.optim.Optimizer): 149 | def __init__(self, params, lr, betas=(0.9, 0.999)) -> None: 150 | default = dict(lr=lr, betas=betas) 151 | super().__init__(params, default) 152 | self.param_to_group = {} 153 | # init states 154 | for group in self.param_groups: 155 | for p in group['params']: 156 | if p.requires_grad: 157 | self.param_to_group[p] = group 158 | state = self.state[p] 159 | state['step'] = 0 160 | state['exp_avg'] = torch.zeros_like(p) 161 | state['exp_avg_sq'] = torch.zeros_like(p) 162 | 163 | def step(self, closure=None): 164 | loss = None 165 | if closure is not None: 166 | with torch.enable_grad(): 167 | loss = closure() 168 | 169 | params = [ 170 | p for group in self.param_groups for p in group['params'] if p.grad is not None] 171 | 172 | for i, p in enumerate(params): 173 | state = self.state[p] 174 | group = self.param_to_group[p] 175 | state['step'] += 1 176 | beta1, beta2 = group['betas'] 177 | adam(state['step'], group['lr'], p, p.grad, state['exp_avg'], 178 | state['exp_avg_sq'], beta1=beta1, beta2=beta2) 179 | 180 | return loss 181 | 182 | 183 | @torch.no_grad() 184 | def test_adam(): 185 | params = list(gpt2_toy().cpu().parameters()) 186 | for _, p in enumerate(params): 187 | if p.grad is None and p.requires_grad: 188 | p.grad = torch.ones_like(p.data, dtype=torch.float) * 0.12345 189 | 190 | params_gt = copy.deepcopy(params) 191 | for _, p in enumerate(params_gt): 192 | if p.grad is None and p.requires_grad: 193 | p.grad = torch.ones_like(p.data, dtype=torch.float) * 0.12345 194 | optimizer = Adam(params_gt, 1e-3) 195 | optimizer.step() 
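    # Note (added): each offloader configuration below must reproduce params_gt
    # exactly after one step, while the untouched copies in `params` must still
    # differ, as asserted at the end of this test.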
196 | 
197 |     test_config = [
198 |         {'n_entries': 1, 'backend': None, 'prefetch': 0, 'vecio': False},
199 | 
200 |         {'n_entries': 1, 'backend': 'uring', 'prefetch': 0, 'vecio': False},
201 |         {'n_entries': 8, 'backend': 'uring', 'prefetch': 2, 'vecio': False},
202 | 
203 |         {'n_entries': 1, 'backend': 'uring', 'prefetch': 0, 'vecio': True},
204 |         {'n_entries': 8, 'backend': 'uring', 'prefetch': 2, 'vecio': True},
205 | 
206 |         {'n_entries': 1, 'backend': 'aio', 'prefetch': 0, 'vecio': False},
207 |         {'n_entries': 8, 'backend': 'aio', 'prefetch': 2, 'vecio': False},
208 | 
209 |         {'n_entries': 1, 'backend': 'aio', 'prefetch': 0, 'vecio': True},
210 |         {'n_entries': 8, 'backend': 'aio', 'prefetch': 2, 'vecio': True},
211 | 
212 |         {'n_entries': 1, 'backend': 'pthread', 'prefetch': 0, 'vecio': False},
213 |         {'n_entries': 8, 'backend': 'pthread', 'prefetch': 2, 'vecio': False},
214 | 
215 |         {'n_entries': 1, 'backend': 'pthread', 'prefetch': 0, 'vecio': True},
216 |         {'n_entries': 8, 'backend': 'pthread', 'prefetch': 2, 'vecio': True},
217 |     ]
218 | 
219 |     for i, cfg in enumerate(test_config):
220 |         params_test = copy.deepcopy(params)
221 |         for _, p in enumerate(params_test):
222 |             if p.grad is None and p.requires_grad:
223 |                 p.grad = torch.ones_like(p.data, dtype=torch.float) * 0.12345
224 |         if cfg['backend'] is None:
225 |             offloader = None
226 |         else:
227 |             offloader = DiskOffloader(
228 |                 '.', cfg['n_entries'], backend=cfg['backend'])
229 |         optimizer_test = NVMEAdam(
230 |             params_test, 1e-3, offloader=offloader, prefetch=cfg['prefetch'], vecio=cfg['vecio'])
231 |         optimizer_test.step()
232 | 
233 |         for p1, p2, p3 in zip(params_gt, params_test, params):
234 |             assert torch.equal(p1, p2)
235 |             assert not torch.equal(p1, p3)
236 | 
237 | 
238 | if __name__ == '__main__':
239 |     test_adam()
240 | 
--------------------------------------------------------------------------------
/tests/test_asyncio.cpp:
--------------------------------------------------------------------------------
1 | #define CATCH_CONFIG_MAIN
2 | 
3 | #include <fcntl.h>
4 | #include <unistd.h>
5 | #include <cstdio>
6 | #include <functional>
7 | #include <string>
8 | #include "aio.h"
9 | #include "catch.hpp"
10 | #include "uring.h"
11 | using namespace std;
12 | 
13 | void callback_n(int &x)
14 | {
15 |     x++;
16 | }
17 | 
18 | void callback_empty()
19 | {
20 | }
21 | 
22 | TEST_CASE( "Test async io function of libaio and liburing") {
23 | 
24 |     AsyncIO *aios[] = {
25 |         new AIOAsyncIO(1),
26 |         new AIOAsyncIO(8),
27 |         new AIOAsyncIO(16),
28 |         new UringAsyncIO(1),
29 |         new UringAsyncIO(8),
30 |         new UringAsyncIO(16),
31 |     };
32 |     auto aio_idx = GENERATE(range(0, 6));
33 |     auto aio = aios[aio_idx];
34 |     string aio_str = to_string(aio_idx);
35 | 
36 |     SECTION("read and write double array to a file" + aio_str) {
37 |         int fd = open("./test.txt", O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
38 |         const int n_loop = 5, n_len = 18;
39 | 
40 |         double data1[n_loop][n_len];
41 |         int n = 0, offset = 0;
42 |         size_t len;
43 |         for (int i = 0; i < n_loop; i++) {
44 |             auto fn = std::bind(callback_n, std::ref(n));
45 |             //len = strlen(data1[i]) + 1;
46 |             len = n_len * sizeof(double);
47 |             aio->write(fd, data1[i], len, offset, fn);
48 |             offset += len;
49 |         }
50 |         aio->sync_write_events();
51 | 
52 |         double data2[n_loop][n_len];
53 |         n = 0, offset = 0;
54 |         for (int i = 0; i < n_loop; i++) {
55 |             auto fn = std::bind(callback_n, std::ref(n));
56 |             len = n_len * sizeof(double);
57 |             aio->read(fd, data2[i], len, offset, fn);
58 |             offset += len;
59 |         }
60 |         aio->sync_read_events();
61 |         for (int i = 0; i < n_loop; i++) {
62 |             for (int j = 0; j < 
n_len; j++) { 63 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 64 | } 65 | } 66 | close(fd); 67 | remove("./test.txt"); 68 | } 69 | SECTION("read and write none double array to a file" + aio_str) { 70 | int fd = open("./test.txt", O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 71 | const int n_loop = 0, n_len = 0; 72 | 73 | double data1[n_loop][n_len]; 74 | int n = 0, offset = 0; 75 | size_t len; 76 | for (int i = 0; i < n_loop; i++) { 77 | auto fn = std::bind(callback_n, std::ref(n)); 78 | //len = strlen(data1[i]) + 1; 79 | len = n_len * sizeof(double); 80 | aio->write(fd, data1[i], len, offset, fn); 81 | offset += len; 82 | } 83 | aio->sync_write_events(); 84 | 85 | double data2[n_loop][n_len]; 86 | n = 0, offset = 0; 87 | for (int i = 0; i < n_loop; i++) { 88 | auto fn = std::bind(callback_n, std::ref(n)); 89 | len = n_len * sizeof(double); 90 | aio->read(fd, data2[i], len, offset, fn); 91 | offset += len; 92 | } 93 | aio->sync_read_events(); 94 | for (int i = 0; i < n_loop; i++) { 95 | for (int j = 0; j < n_len; j++) { 96 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 97 | } 98 | } 99 | close(fd); 100 | remove("./test.txt"); 101 | } 102 | SECTION("read and write small double array to a file" + aio_str) { 103 | int fd = open("./test.txt", O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 104 | const int n_loop = 1, n_len = 1; 105 | 106 | double data1[n_loop][n_len]; 107 | int n = 0, offset = 0; 108 | size_t len; 109 | for (int i = 0; i < n_loop; i++) { 110 | auto fn = std::bind(callback_n, std::ref(n)); 111 | //len = strlen(data1[i]) + 1; 112 | len = n_len * sizeof(double); 113 | aio->write(fd, data1[i], len, offset, fn); 114 | offset += len; 115 | } 116 | aio->sync_write_events(); 117 | 118 | double data2[n_loop][n_len]; 119 | n = 0, offset = 0; 120 | for (int i = 0; i < n_loop; i++) { 121 | auto fn = std::bind(callback_n, std::ref(n)); 122 | len = n_len * sizeof(double); 123 | aio->read(fd, data2[i], len, offset, fn); 124 | offset += len; 125 | } 126 | aio->sync_read_events(); 127 | for (int i = 0; i < n_loop; i++) { 128 | for (int j = 0; j < n_len; j++) { 129 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 130 | } 131 | } 132 | close(fd); 133 | remove("./test.txt"); 134 | } 135 | SECTION("read and write large double array to a file" + aio_str) { 136 | int fd = open("./test.txt", O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 137 | const int n_loop = 50, n_len = 50; 138 | 139 | double data1[n_loop][n_len]; 140 | int n = 0, offset = 0; 141 | size_t len; 142 | for (int i = 0; i < n_loop; i++) { 143 | auto fn = std::bind(callback_n, std::ref(n)); 144 | //len = strlen(data1[i]) + 1; 145 | len = n_len * sizeof(double); 146 | aio->write(fd, data1[i], len, offset, fn); 147 | offset += len; 148 | } 149 | aio->sync_write_events(); 150 | 151 | double data2[n_loop][n_len]; 152 | n = 0, offset = 0; 153 | for (int i = 0; i < n_loop; i++) { 154 | auto fn = std::bind(callback_n, std::ref(n)); 155 | len = n_len * sizeof(double); 156 | aio->read(fd, data2[i], len, offset, fn); 157 | offset += len; 158 | } 159 | aio->sync_read_events(); 160 | for (int i = 0; i < n_loop; i++) { 161 | for (int j = 0; j < n_len; j++) { 162 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 163 | } 164 | } 165 | close(fd); 166 | remove("./test.txt"); 167 | } 168 | SECTION("read and write double array to multiple files" + aio_str) { 169 | const int n_loop = 5, n_len = 18; 170 | double data1[n_loop][n_len]; 171 | int n = 0, offset = 0; 172 | size_t len; 173 | int fds[n_loop]; 174 | char file_name[] 
= "testn"; 175 | 176 | for (int i = 0; i < n_loop; i++) { 177 | file_name[4] = (char) (i + 1); 178 | fds[i] = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 179 | auto fn = std::bind(callback_n, std::ref(n)); 180 | len = n_len * sizeof(double); 181 | aio->write(fds[i], data1[i], len, offset, fn); 182 | offset += len; 183 | } 184 | aio->sync_write_events(); 185 | REQUIRE(n == n_loop); 186 | 187 | double data2[n_loop][n_len]; 188 | n = 0, offset = 0; 189 | for (int i = 0; i < n_loop; i++) { 190 | auto fn = std::bind(callback_n, std::ref(n)); 191 | len = n_len * sizeof(double); 192 | aio->read(fds[i], data2[i], len, offset, fn); 193 | offset += len; 194 | } 195 | aio->sync_read_events(); 196 | REQUIRE(n == n_loop); 197 | 198 | for (int i = 0; i < n_loop; i++) { 199 | for (int j = 0; j < n_len; j++) { 200 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 201 | } 202 | } 203 | 204 | for (int i = 0; i < n_loop; i++) { 205 | close(fds[i]); 206 | file_name[4] = (char) (i + 1); 207 | remove(file_name); 208 | } 209 | } 210 | SECTION("read and write none double array to multiple files" + aio_str) { 211 | const int n_loop = 10, n_len = 0; 212 | double data1[n_loop][n_len]; 213 | int n = 0, offset = 0; 214 | size_t len; 215 | int fds[n_loop]; 216 | char file_name[] = "testn"; 217 | 218 | for (int i = 0; i < n_loop; i++) { 219 | file_name[4] = (char) (i + 1); 220 | fds[i] = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 221 | auto fn = std::bind(callback_n, std::ref(n)); 222 | len = n_len * sizeof(double); 223 | aio->write(fds[i], data1[i], len, offset, fn); 224 | offset += len; 225 | } 226 | aio->sync_write_events(); 227 | 228 | double data2[n_loop][n_len]; 229 | n = 0, offset = 0; 230 | for (int i = 0; i < n_loop; i++) { 231 | auto fn = std::bind(callback_n, std::ref(n)); 232 | len = n_len * sizeof(double); 233 | aio->read(fds[i], data2[i], len, offset, fn); 234 | offset += len; 235 | } 236 | aio->sync_read_events(); 237 | 238 | for (int i = 0; i < n_loop; i++) { 239 | for (int j = 0; j < n_len; j++) { 240 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 241 | } 242 | } 243 | 244 | for (int i = 0; i < n_loop; i++) { 245 | close(fds[i]); 246 | file_name[4] = (char) (i + 1); 247 | remove(file_name); 248 | } 249 | } 250 | SECTION("read and write small double array to multiple files" + aio_str) { 251 | const int n_loop = 1, n_len = 1; 252 | double data1[n_loop][n_len]; 253 | int n = 0, offset = 0; 254 | size_t len; 255 | int fds[n_loop]; 256 | char file_name[] = "testn"; 257 | 258 | for (int i = 0; i < n_loop; i++) { 259 | file_name[4] = (char) (i + 1); 260 | fds[i] = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 261 | auto fn = std::bind(callback_n, std::ref(n)); 262 | len = n_len * sizeof(double); 263 | aio->write(fds[i], data1[i], len, offset, fn); 264 | offset += len; 265 | } 266 | aio->sync_write_events(); 267 | 268 | double data2[n_loop][n_len]; 269 | n = 0, offset = 0; 270 | for (int i = 0; i < n_loop; i++) { 271 | auto fn = std::bind(callback_n, std::ref(n)); 272 | len = n_len * sizeof(double); 273 | aio->read(fds[i], data2[i], len, offset, fn); 274 | offset += len; 275 | } 276 | aio->sync_read_events(); 277 | 278 | for (int i = 0; i < n_loop; i++) { 279 | for (int j = 0; j < n_len; j++) { 280 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 281 | } 282 | } 283 | 284 | for (int i = 0; i < n_loop; i++) { 285 | close(fds[i]); 286 | file_name[4] = (char) (i + 1); 287 | remove(file_name); 288 | } 289 | } 290 | SECTION("read and write large 
double array to multiple files" + aio_str) { 291 | const int n_loop = 50, n_len = 50; 292 | double data1[n_loop][n_len]; 293 | int n = 0, offset = 0; 294 | size_t len; 295 | int fds[n_loop]; 296 | string filename = "test", new_name; 297 | 298 | for (int i = 0; i < n_loop; i++) { 299 | new_name = filename + to_string(i); 300 | fds[i] = open(new_name.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 301 | auto fn = std::bind(callback_n, std::ref(n)); 302 | len = n_len * sizeof(double); 303 | aio->write(fds[i], data1[i], len, offset, fn); 304 | offset += len; 305 | } 306 | aio->sync_write_events(); 307 | REQUIRE(n == n_loop); 308 | 309 | double data2[n_loop][n_len]; 310 | n = 0, offset = 0; 311 | for (int i = 0; i < n_loop; i++) { 312 | auto fn = std::bind(callback_n, std::ref(n)); 313 | len = n_len * sizeof(double); 314 | aio->read(fds[i], data2[i], len, offset, fn); 315 | offset += len; 316 | } 317 | aio->sync_read_events(); 318 | REQUIRE(n == n_loop); 319 | 320 | for (int i = 0; i < n_loop; i++) { 321 | for (int j = 0; j < n_len; j++) { 322 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 323 | } 324 | } 325 | 326 | for (int i = 0; i < n_loop; i++) { 327 | close(fds[i]); 328 | new_name = filename + to_string(i); 329 | remove(new_name.c_str()); 330 | } 331 | } 332 | SECTION("use nullptr cb to read and write double array to multiple files" + aio_str) { 333 | const int n_loop = 5, n_len = 18; 334 | double data1[n_loop][n_len]; 335 | int offset = 0; 336 | size_t len; 337 | int fds[n_loop]; 338 | char file_name[] = "testn"; 339 | 340 | for (int i = 0; i < n_loop; i++) { 341 | file_name[4] = (char) (i + 1); 342 | fds[i] = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 343 | len = n_len * sizeof(double); 344 | REQUIRE_NOTHROW(aio->write(fds[i], data1[i], len, offset, nullptr)); 345 | offset += len; 346 | } 347 | REQUIRE_NOTHROW(aio->sync_write_events()); 348 | 349 | double data2[n_loop][n_len]; 350 | offset = 0; 351 | for (int i = 0; i < n_loop; i++) { 352 | len = n_len * sizeof(double); 353 | REQUIRE_NOTHROW(aio->read(fds[i], data2[i], len, offset, nullptr)); 354 | offset += len; 355 | } 356 | REQUIRE_NOTHROW(aio->sync_read_events()); 357 | 358 | for (int i = 0; i < n_loop; i++) { 359 | for (int j = 0; j < n_len; j++) { 360 | REQUIRE(data1[i][j] == Approx(data2[i][j]).epsilon(0.001)); 361 | } 362 | } 363 | 364 | for (int i = 0; i < n_loop; i++) { 365 | close(fds[i]); 366 | file_name[4] = (char) (i + 1); 367 | remove(file_name); 368 | } 369 | } 370 | SECTION("use empty cb to read and write double array to multiple files" + aio_str) { 371 | const int n_loop = 5, n_len = 18; 372 | double data1[n_loop][n_len]; 373 | int offset = 0; 374 | size_t len; 375 | int fds[n_loop]; 376 | char file_name[] = "testn"; 377 | 378 | for (int i = 0; i < n_loop; i++) { 379 | file_name[4] = (char) (i + 1); 380 | fds[i] = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 381 | len = n_len * sizeof(double); 382 | REQUIRE_NOTHROW(aio->write(fds[i], data1[i], len, offset, callback_empty)); 383 | offset += len; 384 | } 385 | REQUIRE_NOTHROW(aio->sync_write_events()); 386 | 387 | double data2[n_loop][n_len]; 388 | offset = 0; 389 | for (int i = 0; i < n_loop; i++) { 390 | len = n_len * sizeof(double); 391 | REQUIRE_NOTHROW(aio->read(fds[i], data2[i], len, offset, callback_empty)); 392 | offset += len; 393 | } 394 | REQUIRE_NOTHROW(aio->sync_read_events()); 395 | 396 | for (int i = 0; i < n_loop; i++) { 397 | for (int j = 0; j < n_len; j++) { 398 | REQUIRE(data1[i][j] == 
Approx(data2[i][j]).epsilon(0.001)); 399 | } 400 | } 401 | 402 | for (int i = 0; i < n_loop; i++) { 403 | close(fds[i]); 404 | file_name[4] = (char) (i + 1); 405 | remove(file_name); 406 | } 407 | } 408 | delete aio; 409 | } 410 | -------------------------------------------------------------------------------- /tests/test_disk_offloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from tensornvme import DiskOffloader 4 | 5 | 6 | @pytest.mark.parametrize('backend', ['uring', 'aio']) 7 | def test_sync_io(backend): 8 | x = torch.rand(2, 2) 9 | x_copy = x.clone() 10 | of = DiskOffloader('.', backend=backend) 11 | try: 12 | of.sync_read(x) 13 | assert False 14 | except RuntimeError: 15 | pass 16 | of.sync_write(x) 17 | assert x.storage().size() == 0 18 | of.sync_read(x) 19 | assert torch.equal(x, x_copy) 20 | 21 | 22 | @pytest.mark.parametrize('backend', ['uring', 'aio']) 23 | def test_async_io(backend): 24 | x = torch.rand(2, 2) 25 | x_copy = x.clone() 26 | of = DiskOffloader('.', backend=backend) 27 | try: 28 | of.async_read(x) 29 | assert False 30 | except RuntimeError: 31 | pass 32 | of.async_write(x) 33 | # assert x.storage().size() > 0 34 | of.sync_write_events() 35 | assert x.storage().size() == 0 36 | of.sync_read(x) 37 | of.sync_read_events() 38 | assert torch.equal(x, x_copy) 39 | 40 | 41 | @pytest.mark.parametrize('backend', ['uring', 'aio']) 42 | def test_sync_vec_io(backend): 43 | x = torch.rand(2, 2) 44 | y = torch.rand(2, 2, 2) 45 | x_copy = x.clone() 46 | y_copy = y.clone() 47 | of = DiskOffloader('.', backend=backend) 48 | try: 49 | of.sync_readv([x, y]) 50 | assert False 51 | except RuntimeError: 52 | pass 53 | of.sync_writev([x, y]) 54 | assert x.storage().size() == 0 55 | assert y.storage().size() == 0 56 | try: 57 | of.sync_readv(x) 58 | assert False 59 | except RuntimeError: 60 | pass 61 | try: 62 | of.sync_readv([y, x]) 63 | assert False 64 | except RuntimeError: 65 | pass 66 | of.sync_readv([x, y]) 67 | assert torch.equal(x, x_copy) 68 | assert torch.equal(y, y_copy) 69 | 70 | 71 | @pytest.mark.parametrize('backend', ['uring', 'aio']) 72 | def test_async_vec_io(backend): 73 | x = torch.rand(2, 2) 74 | y = torch.rand(2, 2, 2) 75 | x_copy = x.clone() 76 | y_copy = y.clone() 77 | of = DiskOffloader('.', backend=backend) 78 | try: 79 | of.async_readv([x, y]) 80 | assert False 81 | except RuntimeError: 82 | pass 83 | of.async_writev([x, y]) 84 | # assert x.storage().size() > 0 85 | # assert y.storage().size() > 0 86 | of.sync_write_events() 87 | assert x.storage().size() == 0 88 | assert y.storage().size() == 0 89 | try: 90 | of.async_readv(x) 91 | assert False 92 | except RuntimeError: 93 | pass 94 | try: 95 | of.async_readv([y, x]) 96 | assert False 97 | except RuntimeError: 98 | pass 99 | of.async_readv([x, y]) 100 | of.sync_read_events() 101 | assert torch.equal(x, x_copy) 102 | assert torch.equal(y, y_copy) 103 | 104 | 105 | if __name__ == '__main__': 106 | test_sync_io('uring') 107 | test_async_io('uring') 108 | test_sync_vec_io('uring') 109 | test_async_vec_io('uring') 110 | -------------------------------------------------------------------------------- /tests/test_offload.cpp: -------------------------------------------------------------------------------- 1 | #define CATCH_CONFIG_MAIN 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "aio.h" 9 | #include "catch.hpp" 10 | #include "uring.h" 11 | #include "offload.h" 12 | using namespace std; 13 | 14 | 15 | 
TEST_CASE( "Test offload" ) {
16 |     Offloader *offload = new Offloader("./test", 2, "uring");
17 |     at::Tensor tensor1 = at::ones({12, 12}), tensor2 = at::rand({12, 12}), tensor3 = at::zeros({12, 12});
18 |     offload->async_write(tensor1, "12345");
19 |     offload->sync_write_events();
20 |     offload->async_read(tensor2, "12345");
21 |     offload->sync_read_events();
22 | 
23 |     auto foo_1 = tensor1.accessor<float, 2>();
24 |     auto foo_2 = tensor2.accessor<float, 2>();
25 |     auto foo_3 = tensor3.accessor<float, 2>();
26 | 
27 |     for(int i = 0; i < 12; i++) {
28 |         for(int j = 0; j < 12; j++) {
29 |             // use the accessor foo_a to get tensor data.
30 |             REQUIRE(foo_1[i][j] == foo_2[i][j]);
31 |             REQUIRE(foo_1[i][j] != foo_3[i][j]);
32 |         }
33 |     }
34 | }
--------------------------------------------------------------------------------
/tests/test_space_mgr.cpp:
--------------------------------------------------------------------------------
1 | #define CATCH_CONFIG_MAIN
2 | #include <vector>
3 | #include "catch.hpp"
4 | 
5 | #define private public
6 | #include "space_mgr.h"
7 | 
8 | 
9 | TEST_CASE( "Test space manager function" ) {
10 |     SpaceManager space_mgr(0);
11 |     SpaceInfo new_avail_space;
12 |     vector<SpaceInfo> test_avail_spaces;
13 |     ull offset;
14 | 
15 |     SECTION( "origin test" ){
16 |         offset = space_mgr.alloc(4);
17 |         // [0, 4) is used
18 |         REQUIRE(offset == 0);
19 |         offset = space_mgr.alloc(8);
20 |         // [0, 12) is used
21 |         REQUIRE(offset == 4);
22 |         space_mgr.free(4, 4);
23 |         // [0, 4) and [8, 12) are used
24 |         offset = space_mgr.alloc(2);
25 |         // [0, 6) and [8, 12) are used
26 |         REQUIRE(offset == 4);
27 |         offset = space_mgr.alloc(4);
28 |         // [0, 6) and [8, 16) are used
29 |         REQUIRE(offset == 12);
30 |         space_mgr.free(0, 2);
31 |         space_mgr.free(4, 2);
32 |         // [2, 4) and [8, 16) are used
33 |         space_mgr.free(2, 2);
34 |         // [8, 16) is used
35 |         offset = space_mgr.alloc(5);
36 |         // [0, 5) and [8, 16) are used
37 |         REQUIRE(offset == 0);
38 |         offset = space_mgr.alloc(4);
39 |         // [0, 5) and [8, 20) are used
40 |         REQUIRE(offset == 16);
41 |         offset = space_mgr.alloc(2);
42 |         // [0, 7) and [8, 20) are used
43 |         REQUIRE(offset == 5);
44 |     }
45 | 
46 |     SECTION( "alloc and" ) {
47 |         offset = space_mgr.alloc(20);
48 |         REQUIRE(offset == 0);
49 |         REQUIRE(space_mgr.used_bytes == 20);
50 |         auto test_iter = test_avail_spaces.begin();
51 |         for (auto iter = space_mgr.avail_spaces.begin(); iter != space_mgr.avail_spaces.end(); iter++){
52 |             REQUIRE(iter->first == test_iter->first);
53 |             REQUIRE(iter->second == test_iter->second);
54 |             REQUIRE_NOTHROW(test_iter++);
55 |         }
56 | 
57 |         SECTION( "drop the last one and" ) {
58 |             space_mgr.free(19, 1);
59 |             new_avail_space.first = 19;
60 |             new_avail_space.second = 1;
61 |             test_avail_spaces.push_back(new_avail_space);
62 |             test_iter = test_avail_spaces.begin();
63 |             for (auto iter = space_mgr.avail_spaces.begin(); iter != space_mgr.avail_spaces.end(); iter++) {
64 |                 REQUIRE(iter->first == test_iter->first);
65 |                 REQUIRE(iter->second == test_iter->second);
66 |                 REQUIRE_NOTHROW(test_iter++);
67 |             }
68 |             REQUIRE(space_mgr.used_bytes == 19);
69 | 
70 |             SECTION( "alloc a little one" ) {
71 |                 offset = space_mgr.alloc(1);
72 |                 REQUIRE(offset == 19);
73 |                 REQUIRE(space_mgr.used_bytes == 20);
74 |             }
75 | 
76 |             SECTION( "alloc a big one" ) {
77 |                 offset = space_mgr.alloc(2);
78 |                 CHECK(offset == 19);
79 |                 CHECK(space_mgr.used_bytes == 21);
80 |             }
81 | 
82 |             SECTION( "alloc a bigger one" ) {
83 |                 offset = space_mgr.alloc(10);
84 |                 CHECK(offset == 19);
85 |                 CHECK(space_mgr.used_bytes == 29);
86 |             }
87 |         }
88 | 
89 |         SECTION( "drop the last ones and" ) 
{ 90 | space_mgr.free(16, 4); 91 | new_avail_space.first = 16; 92 | new_avail_space.second = 4; 93 | test_avail_spaces.push_back(new_avail_space); 94 | test_iter = test_avail_spaces.begin(); 95 | for (auto iter = space_mgr.avail_spaces.begin(); iter != space_mgr.avail_spaces.end(); iter++) { 96 | REQUIRE(iter->first == test_iter->first); 97 | REQUIRE(iter->second == test_iter->second); 98 | REQUIRE_NOTHROW(test_iter++); 99 | } 100 | REQUIRE(space_mgr.used_bytes == 16); 101 | 102 | SECTION( "alloc a little one" ) { 103 | offset = space_mgr.alloc(4); 104 | REQUIRE(offset == 16); 105 | REQUIRE(space_mgr.used_bytes == 20); 106 | } 107 | 108 | SECTION( "alloc a big one" ) { 109 | offset = space_mgr.alloc(5); 110 | CHECK(offset == 16); 111 | CHECK(space_mgr.used_bytes == 21); 112 | } 113 | 114 | SECTION( "alloc a bigger one" ) { 115 | offset = space_mgr.alloc(10); 116 | CHECK(offset == 16); 117 | CHECK(space_mgr.used_bytes == 26); 118 | } 119 | } 120 | 121 | SECTION( "drop middle ones and" ) { 122 | space_mgr.free(10, 4); 123 | new_avail_space.first = 10; 124 | new_avail_space.second = 4; 125 | test_avail_spaces.push_back(new_avail_space); 126 | test_iter = test_avail_spaces.begin(); 127 | for (auto iter = space_mgr.avail_spaces.begin(); iter != space_mgr.avail_spaces.end(); iter++) { 128 | REQUIRE(iter->first == test_iter->first); 129 | REQUIRE(iter->second == test_iter->second); 130 | REQUIRE_NOTHROW(test_iter++); 131 | } 132 | REQUIRE(space_mgr.used_bytes == 20); 133 | 134 | SECTION( "alloc a small one and" ) { 135 | offset = space_mgr.alloc(4); 136 | REQUIRE(offset == 10); 137 | REQUIRE(space_mgr.used_bytes == 20); 138 | } 139 | 140 | SECTION( "alloc a big one and" ) { 141 | offset = space_mgr.alloc(5); 142 | REQUIRE(offset == 20); 143 | REQUIRE(space_mgr.used_bytes == 25); 144 | } 145 | } 146 | 147 | SECTION( "drop the first one and" ) { 148 | space_mgr.free(0, 1); 149 | new_avail_space.first = 0; 150 | new_avail_space.second = 1; 151 | test_avail_spaces.push_back(new_avail_space); 152 | test_iter = test_avail_spaces.begin(); 153 | for (auto iter = space_mgr.avail_spaces.begin(); iter != space_mgr.avail_spaces.end(); iter++) { 154 | REQUIRE(iter->first == test_iter->first); 155 | REQUIRE(iter->second == test_iter->second); 156 | REQUIRE_NOTHROW(test_iter++); 157 | } 158 | REQUIRE(space_mgr.used_bytes == 20); 159 | 160 | SECTION( "alloc a small one and" ) { 161 | offset = space_mgr.alloc(1); 162 | REQUIRE(offset == 0); 163 | REQUIRE(space_mgr.used_bytes == 20); 164 | } 165 | 166 | SECTION( "alloc a big one and" ) { 167 | offset = space_mgr.alloc(2); 168 | REQUIRE(offset == 20); 169 | REQUIRE(space_mgr.used_bytes == 22); 170 | } 171 | } 172 | } 173 | 174 | } -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | --------------------------------------------------------------------------------