├── .github └── workflows │ ├── publish-wheel.yml │ ├── python.yml │ └── rust.yml ├── .gitignore ├── .vscode └── settings.json ├── Cargo.lock ├── Cargo.toml ├── LICENSE.md ├── README.md ├── demos └── generate_text.ipynb ├── environment.yml ├── pyproject.toml ├── src ├── bindings │ ├── in_memory_index.rs │ ├── memmap_index.rs │ ├── mod.rs │ ├── sharded_in_memory_index.rs │ └── sharded_memmap_index.rs ├── in_memory_index.rs ├── lib.rs ├── memmap_index.rs ├── mmap_slice.rs ├── par_quicksort.rs ├── sample.rs ├── sharded_in_memory_index.rs ├── sharded_memmap_index.rs ├── table.rs └── util.rs ├── tests └── tests.rs └── tokengrams ├── __init__.py ├── benchmark ├── InMemoryIndex_build_times.png ├── InMemoryIndex_count_next_times.png ├── MemmapIndex_build_times.png ├── MemmapIndex_count_next_times.png └── benchmark.py ├── tests ├── __init__.py ├── test_gram_index.py └── test_sharded_index.py ├── tokengrams.pyi └── utils ├── __init__.py └── tokenize_hf_dataset.py /.github/workflows/publish-wheel.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | on: 3 | workflow_dispatch: 4 | 5 | jobs: 6 | build: 7 | name: Build wheels for ${{ matrix.os }} - Python ${{ matrix.python-version }} 8 | strategy: 9 | fail-fast: false 10 | matrix: 11 | include: 12 | # 3.10 13 | - os: ubuntu-latest 14 | target: x86_64-unknown-linux-gnu 15 | python-version: '3.10' 16 | - os: ubuntu-latest 17 | target: aarch64-unknown-linux-gnu 18 | python-version: '3.10' 19 | - os: macos-latest 20 | target: x86_64-apple-darwin 21 | python-version: '3.10' 22 | - os: macos-latest 23 | target: aarch64-apple-darwin 24 | python-version: '3.10' 25 | - os: windows-latest 26 | target: x86_64-pc-windows-msvc 27 | python-version: '3.10' 28 | # 3.11- 29 | - os: ubuntu-latest 30 | target: x86_64-unknown-linux-gnu 31 | python-version: '3.11' 32 | - os: ubuntu-latest 33 | target: aarch64-unknown-linux-gnu 34 | python-version: '3.11' 35 | - os: macos-latest 36 | target: x86_64-apple-darwin 37 | python-version: '3.11' 38 | - os: macos-latest 39 | target: aarch64-apple-darwin 40 | python-version: '3.11' 41 | - os: windows-latest 42 | target: x86_64-pc-windows-msvc 43 | python-version: '3.11' 44 | 45 | runs-on: ${{ matrix.os }} 46 | steps: 47 | - uses: actions/checkout@v3 48 | - name: Set up Python ${{ matrix.python-version }} 49 | uses: actions/setup-python@v4 50 | with: 51 | python-version: ${{ matrix.python-version }} 52 | - name: Set up Rust 53 | uses: actions-rs/toolchain@v1 54 | with: 55 | profile: minimal 56 | toolchain: stable 57 | target: ${{ matrix.target }} 58 | override: true 59 | - name: Add macOS Rust targets 60 | if: matrix.os == 'macos-latest' 61 | run: | 62 | rustup target add x86_64-apple-darwin aarch64-apple-darwin 63 | - name: Build wheels 64 | uses: PyO3/maturin-action@v1 65 | with: 66 | target: ${{ matrix.target }} 67 | args: --release --out dist --interpreter python${{ matrix.python-version }} ${{ matrix.os == 'macos-latest' && '--target universal2-apple-darwin' || '' }} 68 | manylinux: auto 69 | - name: Upload wheels 70 | uses: actions/upload-artifact@v2 71 | with: 72 | name: wheels 73 | path: dist 74 | 75 | build-sdist: 76 | name: Build source distribution 77 | runs-on: ubuntu-latest 78 | steps: 79 | - uses: actions/checkout@v3 80 | - name: Build sdist 81 | uses: PyO3/maturin-action@v1 82 | with: 83 | command: sdist 84 | args: --out dist 85 | - name: Upload sdist 86 | uses: actions/upload-artifact@v2 87 | with: 88 | name: wheels 89 | path: dist 90 | 91 | publish: 92 | name: 
Publish to PyPI 93 | needs: [build, build-sdist] 94 | runs-on: ubuntu-latest 95 | steps: 96 | - uses: actions/download-artifact@v2 97 | with: 98 | name: wheels 99 | path: dist 100 | - name: Publish to PyPI 101 | uses: pypa/gh-action-pypi-publish@release/v1 102 | with: 103 | user: __token__ 104 | password: ${{ secrets.PYPI_API_TOKEN }} 105 | packages_dir: dist/ 106 | skip_existing: true -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Conda 14 | uses: conda-incubator/setup-miniconda@v3 15 | with: 16 | python-version: "3.10" 17 | miniforge-version: latest 18 | use-mamba: true 19 | mamba-version: "*" 20 | - name: Test Python 21 | env: 22 | PYTHONPATH: /home/runner/work/tokengrams/tokengrams 23 | shell: bash -l {0} 24 | run: | 25 | mamba install -c conda-forge numpy pytest hypothesis maturin 26 | maturin develop 27 | maturin build 28 | python -m pip install --user ./target/wheels/tokengrams*.whl 29 | pytest -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | # MacOS 166 | .DS_Store -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "rust-analyzer.linkedProjects": [ 3 | "./Cargo.toml", 4 | ], 5 | "python.testing.pytestArgs": [ 6 | "tests" 7 | ], 8 | "python.testing.unittestEnabled": false, 9 | "python.testing.pytestEnabled": true 10 | } -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "anyhow" 7 | version = "1.0.81" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" 10 | 11 | [[package]] 12 | name = "autocfg" 13 | version = "1.1.0" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 16 | 17 | [[package]] 18 | name = "bincode" 19 | version = "1.3.3" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" 22 | dependencies = [ 23 | "serde", 24 | ] 25 | 26 | [[package]] 27 | name = "cfg-if" 28 | version = "1.0.0" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 31 | 32 | [[package]] 33 | name = "console" 34 | version = "0.15.8" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" 37 | dependencies = [ 38 | "encode_unicode", 39 | "lazy_static", 40 | "libc", 41 | "unicode-width", 42 | "windows-sys", 43 | ] 44 | 45 | [[package]] 46 | name = "crossbeam-deque" 47 | version = "0.8.5" 48 | source = "registry+https://github.com/rust-lang/crates.io-index" 49 | checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" 50 | dependencies = [ 51 | "crossbeam-epoch", 52 | "crossbeam-utils", 53 | ] 54 | 55 | [[package]] 56 | name = "crossbeam-epoch" 57 | version = "0.9.18" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 60 | dependencies = [ 61 | "crossbeam-utils", 62 | ] 63 | 64 | [[package]] 65 | name = "crossbeam-utils" 66 | version = "0.8.19" 67 | source = "registry+https://github.com/rust-lang/crates.io-index" 68 | checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" 69 | 70 | [[package]] 71 | name = "either" 72 | version = "1.10.0" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" 75 | 76 | [[package]] 77 | name = "encode_unicode" 78 | version = "0.3.6" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" 81 | 82 | [[package]] 83 | name = "erased-serde" 84 | version = "0.4.5" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d" 87 | dependencies = [ 88 | "serde", 89 | "typeid", 90 | ] 
91 | 92 | [[package]] 93 | name = "funty" 94 | version = "2.0.0" 95 | source = "registry+https://github.com/rust-lang/crates.io-index" 96 | checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" 97 | 98 | [[package]] 99 | name = "getrandom" 100 | version = "0.1.16" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" 103 | dependencies = [ 104 | "cfg-if", 105 | "libc", 106 | "wasi 0.9.0+wasi-snapshot-preview1", 107 | ] 108 | 109 | [[package]] 110 | name = "getrandom" 111 | version = "0.2.11" 112 | source = "registry+https://github.com/rust-lang/crates.io-index" 113 | checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" 114 | dependencies = [ 115 | "cfg-if", 116 | "libc", 117 | "wasi 0.11.0+wasi-snapshot-preview1", 118 | ] 119 | 120 | [[package]] 121 | name = "heck" 122 | version = "0.5.0" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" 125 | 126 | [[package]] 127 | name = "indicatif" 128 | version = "0.17.8" 129 | source = "registry+https://github.com/rust-lang/crates.io-index" 130 | checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" 131 | dependencies = [ 132 | "console", 133 | "instant", 134 | "number_prefix", 135 | "portable-atomic", 136 | "unicode-width", 137 | ] 138 | 139 | [[package]] 140 | name = "indoc" 141 | version = "2.0.4" 142 | source = "registry+https://github.com/rust-lang/crates.io-index" 143 | checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" 144 | 145 | [[package]] 146 | name = "instant" 147 | version = "0.1.12" 148 | source = "registry+https://github.com/rust-lang/crates.io-index" 149 | checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" 150 | dependencies = [ 151 | "cfg-if", 152 | ] 153 | 154 | [[package]] 155 | name = "inventory" 156 | version = "0.3.15" 157 | source = "registry+https://github.com/rust-lang/crates.io-index" 158 | checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" 159 | 160 | [[package]] 161 | name = "lazy_static" 162 | version = "1.4.0" 163 | source = "registry+https://github.com/rust-lang/crates.io-index" 164 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 165 | 166 | [[package]] 167 | name = "libc" 168 | version = "0.2.151" 169 | source = "registry+https://github.com/rust-lang/crates.io-index" 170 | checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" 171 | 172 | [[package]] 173 | name = "memmap2" 174 | version = "0.9.4" 175 | source = "registry+https://github.com/rust-lang/crates.io-index" 176 | checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" 177 | dependencies = [ 178 | "libc", 179 | ] 180 | 181 | [[package]] 182 | name = "memoffset" 183 | version = "0.9.0" 184 | source = "registry+https://github.com/rust-lang/crates.io-index" 185 | checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" 186 | dependencies = [ 187 | "autocfg", 188 | ] 189 | 190 | [[package]] 191 | name = "number_prefix" 192 | version = "0.4.0" 193 | source = "registry+https://github.com/rust-lang/crates.io-index" 194 | checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" 195 | 196 | [[package]] 197 | name = "once_cell" 198 | version = "1.19.0" 199 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 200 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" 201 | 202 | [[package]] 203 | name = "portable-atomic" 204 | version = "1.6.0" 205 | source = "registry+https://github.com/rust-lang/crates.io-index" 206 | checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" 207 | 208 | [[package]] 209 | name = "ppv-lite86" 210 | version = "0.2.17" 211 | source = "registry+https://github.com/rust-lang/crates.io-index" 212 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 213 | 214 | [[package]] 215 | name = "proc-macro2" 216 | version = "1.0.86" 217 | source = "registry+https://github.com/rust-lang/crates.io-index" 218 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" 219 | dependencies = [ 220 | "unicode-ident", 221 | ] 222 | 223 | [[package]] 224 | name = "pyo3" 225 | version = "0.22.2" 226 | source = "registry+https://github.com/rust-lang/crates.io-index" 227 | checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" 228 | dependencies = [ 229 | "anyhow", 230 | "cfg-if", 231 | "indoc", 232 | "libc", 233 | "memoffset", 234 | "once_cell", 235 | "portable-atomic", 236 | "pyo3-build-config", 237 | "pyo3-ffi", 238 | "pyo3-macros", 239 | "unindent", 240 | ] 241 | 242 | [[package]] 243 | name = "pyo3-build-config" 244 | version = "0.22.2" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" 247 | dependencies = [ 248 | "once_cell", 249 | "target-lexicon", 250 | ] 251 | 252 | [[package]] 253 | name = "pyo3-ffi" 254 | version = "0.22.2" 255 | source = "registry+https://github.com/rust-lang/crates.io-index" 256 | checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" 257 | dependencies = [ 258 | "libc", 259 | "pyo3-build-config", 260 | ] 261 | 262 | [[package]] 263 | name = "pyo3-macros" 264 | version = "0.22.2" 265 | source = "registry+https://github.com/rust-lang/crates.io-index" 266 | checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" 267 | dependencies = [ 268 | "proc-macro2", 269 | "pyo3-macros-backend", 270 | "quote", 271 | "syn", 272 | ] 273 | 274 | [[package]] 275 | name = "pyo3-macros-backend" 276 | version = "0.22.2" 277 | source = "registry+https://github.com/rust-lang/crates.io-index" 278 | checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" 279 | dependencies = [ 280 | "heck", 281 | "proc-macro2", 282 | "pyo3-build-config", 283 | "quote", 284 | "syn", 285 | ] 286 | 287 | [[package]] 288 | name = "quickcheck" 289 | version = "0.9.2" 290 | source = "registry+https://github.com/rust-lang/crates.io-index" 291 | checksum = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f" 292 | dependencies = [ 293 | "rand 0.7.3", 294 | "rand_core 0.5.1", 295 | ] 296 | 297 | [[package]] 298 | name = "quote" 299 | version = "1.0.35" 300 | source = "registry+https://github.com/rust-lang/crates.io-index" 301 | checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" 302 | dependencies = [ 303 | "proc-macro2", 304 | ] 305 | 306 | [[package]] 307 | name = "rand" 308 | version = "0.7.3" 309 | source = "registry+https://github.com/rust-lang/crates.io-index" 310 | checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" 311 | dependencies = [ 312 | "getrandom 0.1.16", 313 | "libc", 314 | "rand_chacha 
0.2.2", 315 | "rand_core 0.5.1", 316 | "rand_hc", 317 | ] 318 | 319 | [[package]] 320 | name = "rand" 321 | version = "0.8.5" 322 | source = "registry+https://github.com/rust-lang/crates.io-index" 323 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 324 | dependencies = [ 325 | "libc", 326 | "rand_chacha 0.3.1", 327 | "rand_core 0.6.4", 328 | ] 329 | 330 | [[package]] 331 | name = "rand_chacha" 332 | version = "0.2.2" 333 | source = "registry+https://github.com/rust-lang/crates.io-index" 334 | checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" 335 | dependencies = [ 336 | "ppv-lite86", 337 | "rand_core 0.5.1", 338 | ] 339 | 340 | [[package]] 341 | name = "rand_chacha" 342 | version = "0.3.1" 343 | source = "registry+https://github.com/rust-lang/crates.io-index" 344 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 345 | dependencies = [ 346 | "ppv-lite86", 347 | "rand_core 0.6.4", 348 | ] 349 | 350 | [[package]] 351 | name = "rand_core" 352 | version = "0.5.1" 353 | source = "registry+https://github.com/rust-lang/crates.io-index" 354 | checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" 355 | dependencies = [ 356 | "getrandom 0.1.16", 357 | ] 358 | 359 | [[package]] 360 | name = "rand_core" 361 | version = "0.6.4" 362 | source = "registry+https://github.com/rust-lang/crates.io-index" 363 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 364 | dependencies = [ 365 | "getrandom 0.2.11", 366 | ] 367 | 368 | [[package]] 369 | name = "rand_hc" 370 | version = "0.2.0" 371 | source = "registry+https://github.com/rust-lang/crates.io-index" 372 | checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" 373 | dependencies = [ 374 | "rand_core 0.5.1", 375 | ] 376 | 377 | [[package]] 378 | name = "rayon" 379 | version = "1.10.0" 380 | source = "registry+https://github.com/rust-lang/crates.io-index" 381 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" 382 | dependencies = [ 383 | "either", 384 | "rayon-core", 385 | ] 386 | 387 | [[package]] 388 | name = "rayon-core" 389 | version = "1.12.1" 390 | source = "registry+https://github.com/rust-lang/crates.io-index" 391 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 392 | dependencies = [ 393 | "crossbeam-deque", 394 | "crossbeam-utils", 395 | ] 396 | 397 | [[package]] 398 | name = "serde" 399 | version = "1.0.197" 400 | source = "registry+https://github.com/rust-lang/crates.io-index" 401 | checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" 402 | dependencies = [ 403 | "serde_derive", 404 | ] 405 | 406 | [[package]] 407 | name = "serde_derive" 408 | version = "1.0.197" 409 | source = "registry+https://github.com/rust-lang/crates.io-index" 410 | checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" 411 | dependencies = [ 412 | "proc-macro2", 413 | "quote", 414 | "syn", 415 | ] 416 | 417 | [[package]] 418 | name = "syn" 419 | version = "2.0.72" 420 | source = "registry+https://github.com/rust-lang/crates.io-index" 421 | checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" 422 | dependencies = [ 423 | "proc-macro2", 424 | "quote", 425 | "unicode-ident", 426 | ] 427 | 428 | [[package]] 429 | name = "target-lexicon" 430 | version = "0.12.15" 431 | source = "registry+https://github.com/rust-lang/crates.io-index" 432 | checksum = 
"4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" 433 | 434 | [[package]] 435 | name = "tokengrams" 436 | version = "0.3.3" 437 | dependencies = [ 438 | "anyhow", 439 | "bincode", 440 | "funty", 441 | "indicatif", 442 | "memmap2", 443 | "pyo3", 444 | "quickcheck", 445 | "rand 0.8.5", 446 | "rayon", 447 | "rayon-core", 448 | "serde", 449 | "typetag", 450 | "utf16_literal", 451 | ] 452 | 453 | [[package]] 454 | name = "typeid" 455 | version = "1.0.0" 456 | source = "registry+https://github.com/rust-lang/crates.io-index" 457 | checksum = "059d83cc991e7a42fc37bd50941885db0888e34209f8cfd9aab07ddec03bc9cf" 458 | 459 | [[package]] 460 | name = "typetag" 461 | version = "0.2.17" 462 | source = "registry+https://github.com/rust-lang/crates.io-index" 463 | checksum = "1f7ec175048b96728c30152928c52161bfcc8ea2bd3fb7ed4ccb7dec060b2834" 464 | dependencies = [ 465 | "erased-serde", 466 | "inventory", 467 | "once_cell", 468 | "serde", 469 | "typetag-impl", 470 | ] 471 | 472 | [[package]] 473 | name = "typetag-impl" 474 | version = "0.2.17" 475 | source = "registry+https://github.com/rust-lang/crates.io-index" 476 | checksum = "84b5474fd169a5b02b6782b56bbbbff27e85947d4488e5501123687db3148647" 477 | dependencies = [ 478 | "proc-macro2", 479 | "quote", 480 | "syn", 481 | ] 482 | 483 | [[package]] 484 | name = "unicode-ident" 485 | version = "1.0.12" 486 | source = "registry+https://github.com/rust-lang/crates.io-index" 487 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 488 | 489 | [[package]] 490 | name = "unicode-width" 491 | version = "0.1.11" 492 | source = "registry+https://github.com/rust-lang/crates.io-index" 493 | checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" 494 | 495 | [[package]] 496 | name = "unindent" 497 | version = "0.2.3" 498 | source = "registry+https://github.com/rust-lang/crates.io-index" 499 | checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" 500 | 501 | [[package]] 502 | name = "utf16_literal" 503 | version = "0.2.1" 504 | source = "registry+https://github.com/rust-lang/crates.io-index" 505 | checksum = "316f90fe4a7beb941ce0b6806ba7386f1a515155a42def71825f3f9a232e3f48" 506 | 507 | [[package]] 508 | name = "wasi" 509 | version = "0.9.0+wasi-snapshot-preview1" 510 | source = "registry+https://github.com/rust-lang/crates.io-index" 511 | checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" 512 | 513 | [[package]] 514 | name = "wasi" 515 | version = "0.11.0+wasi-snapshot-preview1" 516 | source = "registry+https://github.com/rust-lang/crates.io-index" 517 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 518 | 519 | [[package]] 520 | name = "windows-sys" 521 | version = "0.52.0" 522 | source = "registry+https://github.com/rust-lang/crates.io-index" 523 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 524 | dependencies = [ 525 | "windows-targets", 526 | ] 527 | 528 | [[package]] 529 | name = "windows-targets" 530 | version = "0.52.4" 531 | source = "registry+https://github.com/rust-lang/crates.io-index" 532 | checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" 533 | dependencies = [ 534 | "windows_aarch64_gnullvm", 535 | "windows_aarch64_msvc", 536 | "windows_i686_gnu", 537 | "windows_i686_msvc", 538 | "windows_x86_64_gnu", 539 | "windows_x86_64_gnullvm", 540 | "windows_x86_64_msvc", 541 | ] 542 | 543 | [[package]] 544 | name = "windows_aarch64_gnullvm" 545 | version = 
"0.52.4" 546 | source = "registry+https://github.com/rust-lang/crates.io-index" 547 | checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" 548 | 549 | [[package]] 550 | name = "windows_aarch64_msvc" 551 | version = "0.52.4" 552 | source = "registry+https://github.com/rust-lang/crates.io-index" 553 | checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" 554 | 555 | [[package]] 556 | name = "windows_i686_gnu" 557 | version = "0.52.4" 558 | source = "registry+https://github.com/rust-lang/crates.io-index" 559 | checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" 560 | 561 | [[package]] 562 | name = "windows_i686_msvc" 563 | version = "0.52.4" 564 | source = "registry+https://github.com/rust-lang/crates.io-index" 565 | checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" 566 | 567 | [[package]] 568 | name = "windows_x86_64_gnu" 569 | version = "0.52.4" 570 | source = "registry+https://github.com/rust-lang/crates.io-index" 571 | checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" 572 | 573 | [[package]] 574 | name = "windows_x86_64_gnullvm" 575 | version = "0.52.4" 576 | source = "registry+https://github.com/rust-lang/crates.io-index" 577 | checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" 578 | 579 | [[package]] 580 | name = "windows_x86_64_msvc" 581 | version = "0.52.4" 582 | source = "registry+https://github.com/rust-lang/crates.io-index" 583 | checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" 584 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tokengrams" 3 | description = "Compute n-gram statistics and model language over pre-tokenized text corpora used to train large language models." 
4 | license = "MIT" 5 | version = "0.3.3" 6 | edition = "2021" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | [lib] 10 | name = "tokengrams" 11 | crate-type = ["cdylib", "rlib"] 12 | 13 | [features] 14 | default = ["pyo3/extension-module"] 15 | rust = [] 16 | 17 | [dependencies] 18 | anyhow = "1.0.81" 19 | bincode = "1.3.3" 20 | funty = "2.0.0" 21 | indicatif = "0.17.8" 22 | memmap2 = "0.9.4" 23 | pyo3 = { version = "0.22.2", features = ["extension-module", "anyhow"] } 24 | rand = "0.8.5" 25 | rayon = "1.10.0" 26 | rayon-core = "1.12.1" 27 | serde = { version = "1.0.197", features = ["derive"] } 28 | typetag = "0.2.17" 29 | utf16_literal = "0.2.1" 30 | 31 | [[test]] 32 | name = "tests" 33 | path = "tests/tests.rs" 34 | 35 | [dev-dependencies] 36 | quickcheck = { version = "0.9", default-features = false } 37 | rand = "0.8.4" 38 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tokengrams 2 | Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$. 3 | 4 | Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to generate text or compute the perplexity of a given text. 5 | 6 | The backend is written in Rust, and the Python bindings are generated using [PyO3](https://github.com/PyO3/pyo3). 7 | 8 | # Installation 9 | 10 | ```bash 11 | pip install tokengrams 12 | ``` 13 | 14 | # Usage 15 | 16 | [Full text generation demo](https://colab.research.google.com/drive/1CEHoIjLboGl8YPbIqnWJlPYMm1wVOrrj?usp=sharing) 17 | 18 | ## Preparing data 19 | 20 | Use a dataset of u16 or u32 tokens, or prepare one from a HuggingFace dataset. 
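If you already have token IDs in memory, you can also write them to disk yourself as raw u16 values in native byte order. This is a minimal sketch rather than a library utility, and the file name is illustrative:

```python
# Save in-memory token IDs as a raw u16 binary file that the index builders can read.
import numpy as np

tokens = np.array([101, 2023, 2003, 1037, 7099], dtype=np.uint16)  # any IDs < 2**16
tokens.tofile("my_corpus.bin")  # usable with MemmapIndex.build or InMemoryIndex.from_token_file
```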
21 | 22 | ```python 23 | # Get pre-tokenized dataset 24 | from huggingface_hub import HfApi, hf_hub_download 25 | 26 | hf_hub_download( 27 | repo_id="EleutherAI/pile-standard-pythia-preshuffled", 28 | repo_type="dataset", 29 | filename="document-00000-of-00020.bin", 30 | local_dir="." 31 | ) 32 | ``` 33 | ```python 34 | # Tokenize HF dataset 35 | from tokengrams import tokenize_hf_dataset 36 | from datasets import load_dataset 37 | from transformers import AutoTokenizer 38 | 39 | tokenize_hf_dataset( 40 | dataset=load_dataset("EleutherAI/lambada_openai", "en"), 41 | tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-160m"), 42 | output_path="lambada.bin", 43 | text_key="text", 44 | append_eod=True, 45 | workers=1, 46 | ) 47 | ``` 48 | 49 | ## Building an index 50 | ```python 51 | from tokengrams import MemmapIndex 52 | 53 | # Create a new index from an on-disk corpus of u16 tokens and save it to a .idx file. 54 | # Set verbose to true to include a progress bar for the index sort. 55 | index = MemmapIndex.build( 56 | "document-00000-of-00020.bin", 57 | "document-00000-of-00020.idx", 58 | vocab=2**16, 59 | verbose=True 60 | ) 61 | 62 | # True for any valid index. 63 | print(index.is_sorted()) 64 | 65 | # Get the count of "hello world" in the corpus. 66 | from transformers import AutoTokenizer 67 | 68 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m") 69 | print(index.count(tokenizer.encode("hello world"))) 70 | 71 | # You can now load the index from disk later using __init__ 72 | index = MemmapIndex( 73 | "document-00000-of-00020.bin", 74 | "document-00000-of-00020.idx", 75 | vocab=2**16 76 | ) 77 | ``` 78 | 79 | ## Using an index 80 | 81 | ```python 82 | # Count how often each token in the corpus succeeds "hello world". 83 | print(index.count_next(tokenizer.encode("hello world"))) 84 | print(index.batch_count_next( 85 | [tokenizer.encode("hello world"), tokenizer.encode("hello universe")] 86 | )) 87 | 88 | # Get smoothed probabilities for query continuations 89 | print(index.smoothed_probs(tokenizer.encode("hello world"))) 90 | print(index.batch_smoothed_probs( 91 | [tokenizer.encode("hello world"), tokenizer.encode("hello universe")] 92 | )) 93 | 94 | # Autoregressively sample 10 tokens using 5-gram language statistics. Initial 95 | # gram statistics are derived from the query, with lower order gram statistics used 96 | # until the sequence contains at least 5 tokens. 97 | print(index.sample_unsmoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20)) 98 | print(index.sample_smoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20)) 99 | 100 | # Query whether the corpus contains "hello world" 101 | print(index.contains(tokenizer.encode("hello world"))) 102 | 103 | # Get all n-grams beginning with "hello world" in the corpus 104 | print(index.positions(tokenizer.encode("hello world"))) 105 | ``` 106 | 107 | ## Scaling 108 | 109 | Corpora small enough to fit in memory can use an InMemoryIndex: 110 | 111 | ```python 112 | from tokengrams import InMemoryIndex 113 | 114 | tokens = [0, 1, 2, 3, 4] 115 | index = InMemoryIndex(tokens, vocab=5) 116 | ``` 117 | 118 | Larger corpora must use a MemmapIndex. 119 | 120 | Some systems struggle with memory mapping extremely large tables (e.g. 40 billion tokens), causing unexpected bus errors. 
To prevent this, split the corpus into shards, then use a ShardedMemmapIndex to sort and query the table shard by shard: 121 | 122 | ```python 123 | from tokengrams import ShardedMemmapIndex 124 | from huggingface_hub import HfApi, hf_hub_download 125 | 126 | files = [ 127 | file for file in HfApi().list_repo_files("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset") 128 | if file.endswith('.bin') 129 | ] 130 | 131 | index_paths = [] 132 | for file in files: 133 | hf_hub_download("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset", filename=file, local_dir=".") 134 | index_paths.append((file, f'{file.removesuffix(".bin")}.idx')) 135 | 136 | index = ShardedMemmapIndex.build(index_paths, vocab=2**16, verbose=True) 137 | ``` 138 | ### Tokens 139 | 140 | Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of $2^{32}$. In practice, however, vocabulary size is limited by the length of the largest token vector the machine can allocate in memory. 141 | 142 | Corpora with vocabulary sizes smaller than $2^{16}$ must use u16 tokens. 143 | 144 | ## Performance 145 | 146 | Index build times for in-memory corpora scale inversely with the number of available CPU threads, whereas if the index reads from or writes to a file it is likely to be IO bound. 147 | 148 | The time complexities of count_next(query) and sample_unsmoothed(query) are O(n log n), where n is approximately the number of completions for the query. The time complexity of sample_smoothed(query) is O(m n log n), where m is the n-gram order. 149 | 150 | 151 | 152 | 153 | 154 | 155 |
(Figures: sample build times for an IO bound index; sample count_next times for an IO bound index.)
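To get a rough sense of count_next latency on your own hardware, here is a minimal timing sketch (not part of the repository) that assumes the `index` and `tokenizer` objects built in the usage examples above:

```python
import time

query = tokenizer.encode("hello world")

start = time.perf_counter()
counts = index.count_next(query)  # one count per vocabulary entry
elapsed = time.perf_counter() - start

print(f"count_next returned {len(counts)} counts in {elapsed * 1000:.2f} ms")
```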
156 | 157 | # Development 158 | 159 | ```bash 160 | cargo build 161 | cargo test 162 | ``` 163 | 164 | Develop Python bindings: 165 | 166 | ```bash 167 | pip install maturin 168 | maturin develop 169 | pytest 170 | ``` 171 | 172 | # Support 173 | 174 | The best way to get support is to open an issue on this repo or post in #interp-across-time in the [EleutherAI Discord server](https://discord.gg/eleutherai). If you've used the library and have had a positive (or negative) experience, we'd love to hear from you! 175 | -------------------------------------------------------------------------------- /demos/generate_text.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pip install -q maturin datasets transformers numpy pandas tokengrams" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from tokengrams import MemmapIndex, tokenize_hf_dataset\n", 19 | "from datasets import load_dataset\n", 20 | "from transformers import AutoTokenizer" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-160m\")\n", 30 | "tokenize_hf_dataset(\n", 31 | " dataset=load_dataset(\"EleutherAI/lambada_openai\", \"en\"),\n", 32 | " tokenizer=tokenizer,\n", 33 | " output_path=\"lambada.bin\",\n", 34 | " text_key=\"text\",\n", 35 | " append_eod=False,\n", 36 | " workers=1,\n", 37 | ")\n", 38 | "\n", 39 | "index = MemmapIndex.build('lambada.bin', 'lambada.idx', vocab=2**16, verbose=True)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Once upon a time, Salem, an older wolf I'd never known as a human, had been the omega of the Boundary Wood pack. But I had seen enough of Shelby when I was clawing my way through the meningitis to know that she had fallen low in Paul's eyes and thus low in the packI'm wearing rented skis, rented ski boots that feel weird and tight and make me walk funny, plus every other kind of snow gear my mom was able to convince me to put on. I drew the line at goggles, and I stuck the unflattering wool hat into my jacket pocket, but from the neck down every inch of me is covered and padded. I don't know if I can move, let alone skiAway from the water that had changed everything for me, that had changed the lives of all of Jace's close friends. None of us would ever be the same again. 
But I knew that I couldn't protect myself from that kind\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "sample = index.sample_unsmoothed(tokenizer.encode(\"Once\"), n=8, k=200, num_samples=1)[0]\n", 57 | "print(tokenizer.decode(sample))" 58 | ] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "base", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.10.14" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 2 82 | } 83 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - numpy 8 | - pytest 9 | - hypothesis 10 | - maturin -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tokengrams" 3 | version = "0.3.3" 4 | description = "Efficiently computing & storing token n-grams from large corpora " 5 | authors = [ 6 | { name = "Nora Belrose", email = "nora@eleuther.ai" }, 7 | { name = "Lucia Quirke", email = "lucia@eleuther.ai" } 8 | ] 9 | dependencies = [ 10 | "numpy>=1.24.4", 11 | "datasets>=1.14.0", 12 | "transformers>=4.11.3", 13 | "tqdm>=4.0.0", 14 | ] 15 | readme = "README.md" 16 | requires-python = ">= 3.10" 17 | 18 | [build-system] 19 | requires = ["maturin>=1.2,<2.0"] 20 | build-backend = "maturin" 21 | 22 | [tool.maturin] 23 | module-name = "tokengrams.tokengrams" 24 | features = ["pyo3/extension-module"] 25 | -------------------------------------------------------------------------------- /src/bindings/in_memory_index.rs: -------------------------------------------------------------------------------- 1 | use crate::in_memory_index::InMemoryIndexRs; 2 | use anyhow::Result; 3 | use pyo3::prelude::*; 4 | 5 | /// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory. 6 | /// Non-generic PyO3 wrapper over InMemoryIndexRs. 7 | #[pyclass] 8 | pub struct InMemoryIndex { 9 | index: Box, 10 | } 11 | 12 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data 13 | /// to other unsigned integer types. 
14 | pub trait InMemoryIndexTrait { 15 | fn save_text(&self, path: String) -> Result<()>; 16 | fn save_table(&self, path: String) -> Result<()>; 17 | fn is_sorted(&self) -> bool; 18 | fn contains(&self, query: Vec) -> bool; 19 | fn positions(&self, query: Vec) -> Vec; 20 | fn count(&self, query: Vec) -> usize; 21 | fn count_next(&self, query: Vec) -> Vec; 22 | fn batch_count_next(&self, queries: Vec>) -> Vec>; 23 | fn sample_unsmoothed( 24 | &self, 25 | query: Vec, 26 | n: usize, 27 | k: usize, 28 | num_samples: usize, 29 | ) -> Result>>; 30 | fn sample_smoothed( 31 | &mut self, 32 | query: Vec, 33 | n: usize, 34 | k: usize, 35 | num_samples: usize, 36 | ) -> Result>>; 37 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec; 38 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>; 39 | fn estimate_deltas(&mut self, n: usize); 40 | } 41 | 42 | #[pymethods] 43 | impl InMemoryIndex { 44 | #[new] 45 | #[pyo3(signature = (tokens, vocab=u16::MAX as usize + 1, verbose=false))] 46 | pub fn new_py(_py: Python, tokens: Vec, vocab: usize, verbose: bool) -> Self { 47 | let index: Box = if vocab <= u16::MAX as usize + 1 { 48 | let tokens: Vec = tokens.iter().map(|&x| x as u16).collect(); 49 | Box::new(InMemoryIndexRs::::new(tokens, Some(vocab), verbose)) 50 | } else { 51 | let tokens: Vec = tokens.iter().map(|&x| x as u32).collect(); 52 | Box::new(InMemoryIndexRs::::new(tokens, Some(vocab), verbose)) 53 | }; 54 | 55 | InMemoryIndex { index } 56 | } 57 | 58 | #[staticmethod] 59 | #[pyo3(signature = (path, token_limit=None, vocab=u16::MAX as usize + 1, verbose=false))] 60 | pub fn from_token_file( 61 | path: String, 62 | token_limit: Option, 63 | vocab: usize, 64 | verbose: bool, 65 | ) -> Result { 66 | if vocab <= u16::MAX as usize + 1 { 67 | Ok(InMemoryIndex { 68 | index: Box::new(InMemoryIndexRs::::from_token_file( 69 | path, 70 | token_limit, 71 | vocab, 72 | verbose, 73 | )?), 74 | }) 75 | } else { 76 | Ok(InMemoryIndex { 77 | index: Box::new(InMemoryIndexRs::::from_token_file( 78 | path, 79 | token_limit, 80 | vocab, 81 | verbose, 82 | )?), 83 | }) 84 | } 85 | } 86 | 87 | #[staticmethod] 88 | #[pyo3(signature = (token_path, index_path, vocab=u16::MAX as usize + 1))] 89 | pub fn from_disk(token_path: String, index_path: String, vocab: usize) -> Result { 90 | if vocab <= u16::MAX as usize + 1 { 91 | Ok(InMemoryIndex { 92 | index: Box::new(InMemoryIndexRs::::from_disk( 93 | token_path, index_path, vocab, 94 | )?), 95 | }) 96 | } else { 97 | Ok(InMemoryIndex { 98 | index: Box::new(InMemoryIndexRs::::from_disk( 99 | token_path, index_path, vocab, 100 | )?), 101 | }) 102 | } 103 | } 104 | 105 | pub fn save_tokens(&self, path: String) -> Result<()> { 106 | self.index.save_text(path) 107 | } 108 | 109 | pub fn save_index(&self, path: String) -> Result<()> { 110 | self.index.save_table(path) 111 | } 112 | 113 | pub fn is_sorted(&self) -> bool { 114 | self.index.is_sorted() 115 | } 116 | 117 | pub fn contains(&self, query: Vec) -> bool { 118 | self.index.contains(query) 119 | } 120 | 121 | pub fn positions(&self, query: Vec) -> Vec { 122 | self.index.positions(query).to_vec() 123 | } 124 | 125 | pub fn count(&self, query: Vec) -> usize { 126 | self.index.count(query) 127 | } 128 | 129 | pub fn count_next(&self, query: Vec) -> Vec { 130 | self.index.count_next(query) 131 | } 132 | 133 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> { 134 | self.index.batch_count_next(queries) 135 | } 136 | 137 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 138 
| pub fn sample_unsmoothed( 139 | &self, 140 | query: Vec, 141 | n: usize, 142 | k: usize, 143 | num_samples: usize, 144 | ) -> Result>> { 145 | self.index.sample_unsmoothed(query, n, k, num_samples) 146 | } 147 | 148 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 149 | /// tokens in the query. 150 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 151 | self.index.get_smoothed_probs(query) 152 | } 153 | 154 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 155 | /// tokens in the query. 156 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 157 | self.index.batch_get_smoothed_probs(queries) 158 | } 159 | 160 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 161 | pub fn sample_smoothed( 162 | &mut self, 163 | query: Vec, 164 | n: usize, 165 | k: usize, 166 | num_samples: usize, 167 | ) -> Result>> { 168 | self.index.sample_smoothed(query, n, k, num_samples) 169 | } 170 | 171 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 172 | /// Improve smoothed model quality by replacing the default delta hyperparameters 173 | /// for models of order n and below with improved estimates over the entire index. 174 | /// , page 16. 175 | pub fn estimate_deltas(&mut self, n: usize) { 176 | self.index.estimate_deltas(n); 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /src/bindings/memmap_index.rs: -------------------------------------------------------------------------------- 1 | use crate::memmap_index::MemmapIndexRs; 2 | use anyhow::Result; 3 | use pyo3::prelude::*; 4 | 5 | /// A memmap index exposes suffix table functionality over text corpora too large to fit in memory. 6 | #[pyclass] 7 | pub struct MemmapIndex { 8 | index: Box, 9 | } 10 | 11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data 12 | /// to other unsigned integer types. 
13 | pub trait MemmapIndexTrait { 14 | fn is_sorted(&self) -> bool; 15 | fn contains(&self, query: Vec) -> bool; 16 | fn positions(&self, query: Vec) -> Vec; 17 | fn count(&self, query: Vec) -> usize; 18 | fn count_next(&self, query: Vec) -> Vec; 19 | fn batch_count_next(&self, queries: Vec>) -> Vec>; 20 | fn sample_unsmoothed( 21 | &self, 22 | query: Vec, 23 | n: usize, 24 | k: usize, 25 | num_samples: usize, 26 | ) -> Result>>; 27 | fn sample_smoothed( 28 | &mut self, 29 | query: Vec, 30 | n: usize, 31 | k: usize, 32 | num_samples: usize, 33 | ) -> Result>>; 34 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec; 35 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>; 36 | fn estimate_deltas(&mut self, n: usize); 37 | } 38 | 39 | #[pymethods] 40 | impl MemmapIndex { 41 | #[new] 42 | #[pyo3(signature = (text_path, table_path, vocab=u16::MAX as usize + 1))] 43 | pub fn new( 44 | _py: Python, 45 | text_path: String, 46 | table_path: String, 47 | vocab: usize, 48 | ) -> PyResult { 49 | if vocab <= u16::MAX as usize + 1 { 50 | Ok(MemmapIndex { 51 | index: Box::new(MemmapIndexRs::::new(text_path, table_path, vocab)?), 52 | }) 53 | } else { 54 | Ok(MemmapIndex { 55 | index: Box::new(MemmapIndexRs::::new(text_path, table_path, vocab)?), 56 | }) 57 | } 58 | } 59 | 60 | #[staticmethod] 61 | #[pyo3(signature = (text_path, table_path, vocab=u16::MAX as usize + 1, verbose=false))] 62 | pub fn build( 63 | text_path: String, 64 | table_path: String, 65 | vocab: usize, 66 | verbose: bool, 67 | ) -> PyResult { 68 | if vocab <= u16::MAX as usize + 1 { 69 | Ok(MemmapIndex { 70 | index: Box::new(MemmapIndexRs::::build( 71 | text_path, table_path, vocab, verbose, 72 | )?), 73 | }) 74 | } else { 75 | Ok(MemmapIndex { 76 | index: Box::new(MemmapIndexRs::::build( 77 | text_path, table_path, vocab, verbose, 78 | )?), 79 | }) 80 | } 81 | } 82 | 83 | pub fn is_sorted(&self) -> bool { 84 | self.index.is_sorted() 85 | } 86 | 87 | pub fn contains(&self, query: Vec) -> bool { 88 | self.index.contains(query) 89 | } 90 | 91 | pub fn positions(&self, query: Vec) -> Vec { 92 | self.index.positions(query).to_vec() 93 | } 94 | 95 | pub fn count(&self, query: Vec) -> usize { 96 | self.index.positions(query).len() 97 | } 98 | 99 | pub fn count_next(&self, query: Vec) -> Vec { 100 | self.index.count_next(query) 101 | } 102 | 103 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> { 104 | self.index.batch_count_next(queries) 105 | } 106 | 107 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 108 | pub fn sample_unsmoothed( 109 | &self, 110 | query: Vec, 111 | n: usize, 112 | k: usize, 113 | num_samples: usize, 114 | ) -> Result>> { 115 | self.index.sample_unsmoothed(query, n, k, num_samples) 116 | } 117 | 118 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 119 | /// tokens in the query. 120 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 121 | self.index.get_smoothed_probs(query) 122 | } 123 | 124 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 125 | /// tokens in the query. 126 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 127 | self.index.batch_get_smoothed_probs(queries) 128 | } 129 | 130 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 
131 | pub fn sample_smoothed( 132 | &mut self, 133 | query: Vec, 134 | n: usize, 135 | k: usize, 136 | num_samples: usize, 137 | ) -> Result>> { 138 | self.index.sample_smoothed(query, n, k, num_samples) 139 | } 140 | 141 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 142 | /// Improve smoothed model quality by replacing the default delta hyperparameters 143 | /// for models of order n and below with improved estimates over the entire index. 144 | /// , page 16. 145 | pub fn estimate_deltas(&mut self, n: usize) { 146 | self.index.estimate_deltas(n); 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/bindings/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod in_memory_index; 2 | pub mod memmap_index; 3 | pub mod sharded_memmap_index; 4 | pub mod sharded_in_memory_index; -------------------------------------------------------------------------------- /src/bindings/sharded_in_memory_index.rs: -------------------------------------------------------------------------------- 1 | use crate::sharded_in_memory_index::ShardedInMemoryIndexRs; 2 | use anyhow::Result; 3 | use pyo3::prelude::*; 4 | 5 | /// Expose suffix table functionality over text corpora too large to fit in memory. 6 | #[pyclass] 7 | pub struct ShardedInMemoryIndex { 8 | index: Box, 9 | } 10 | 11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data 12 | /// to other unsigned integer types. 13 | pub trait ShardedInMemoryIndexTrait { 14 | fn is_sorted(&self) -> bool; 15 | fn contains(&self, query: Vec) -> bool; 16 | fn count(&self, query: Vec) -> usize; 17 | fn count_next(&self, query: Vec) -> Vec; 18 | fn batch_count_next(&self, queries: Vec>) -> Vec>; 19 | fn sample_unsmoothed( 20 | &self, 21 | query: Vec, 22 | n: usize, 23 | k: usize, 24 | num_samples: usize, 25 | ) -> Result>>; 26 | fn sample_smoothed( 27 | &mut self, 28 | query: Vec, 29 | n: usize, 30 | k: usize, 31 | num_samples: usize, 32 | ) -> Result>>; 33 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec; 34 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>; 35 | fn estimate_deltas(&mut self, n: usize); 36 | } 37 | 38 | #[pymethods] 39 | impl ShardedInMemoryIndex { 40 | #[new] 41 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1))] 42 | pub fn new(_py: Python, paths: Vec<(String, String)>, vocab: usize) -> PyResult { 43 | if vocab <= u16::MAX as usize + 1 { 44 | Ok(ShardedInMemoryIndex { 45 | index: Box::new(ShardedInMemoryIndexRs::::new(paths, vocab)?), 46 | }) 47 | } else { 48 | Ok(ShardedInMemoryIndex { 49 | index: Box::new(ShardedInMemoryIndexRs::::new(paths, vocab)?), 50 | }) 51 | } 52 | } 53 | 54 | pub fn is_sorted(&self) -> bool { 55 | self.index.is_sorted() 56 | } 57 | 58 | pub fn contains(&self, query: Vec) -> bool { 59 | self.index.contains(query) 60 | } 61 | 62 | pub fn count(&self, query: Vec) -> usize { 63 | self.index.count(query) 64 | } 65 | 66 | pub fn count_next(&self, query: Vec) -> Vec { 67 | self.index.count_next(query) 68 | } 69 | 70 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> { 71 | self.index.batch_count_next(queries) 72 | } 73 | 74 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 75 | pub fn sample_unsmoothed( 76 | &self, 77 | query: Vec, 78 | n: usize, 79 | k: usize, 80 | num_samples: usize, 81 | ) -> Result>> { 82 | self.index.sample_unsmoothed(query, n, k, num_samples) 83 | } 84 | 85 | /// Returns 
interpolated Kneser-Ney smoothed token probability distribution using all previous 86 | /// tokens in the query. 87 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 88 | self.index.get_smoothed_probs(query) 89 | } 90 | 91 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 92 | /// tokens in the query. 93 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 94 | self.index.batch_get_smoothed_probs(queries) 95 | } 96 | 97 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 98 | pub fn sample_smoothed( 99 | &mut self, 100 | query: Vec, 101 | n: usize, 102 | k: usize, 103 | num_samples: usize, 104 | ) -> Result>> { 105 | self.index.sample_smoothed(query, n, k, num_samples) 106 | } 107 | 108 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 109 | /// Improve smoothed model quality by replacing the default delta hyperparameters 110 | /// for models of order n and below with improved estimates over the entire index. 111 | /// , page 16. 112 | pub fn estimate_deltas(&mut self, n: usize) { 113 | self.index.estimate_deltas(n); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/bindings/sharded_memmap_index.rs: -------------------------------------------------------------------------------- 1 | use crate::sharded_memmap_index::ShardedMemmapIndexRs; 2 | use anyhow::Result; 3 | use pyo3::prelude::*; 4 | 5 | /// Expose suffix table functionality over text corpora too large to fit in memory. 6 | #[pyclass] 7 | pub struct ShardedMemmapIndex { 8 | index: Box, 9 | } 10 | 11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data 12 | /// to other unsigned integer types. 
13 | pub trait ShardedMemmapIndexTrait { 14 | fn is_sorted(&self) -> bool; 15 | fn contains(&self, query: Vec) -> bool; 16 | fn count(&self, query: Vec) -> usize; 17 | fn count_next(&self, query: Vec) -> Vec; 18 | fn batch_count_next(&self, queries: Vec>) -> Vec>; 19 | fn sample_unsmoothed( 20 | &self, 21 | query: Vec, 22 | n: usize, 23 | k: usize, 24 | num_samples: usize, 25 | ) -> Result>>; 26 | fn sample_smoothed( 27 | &mut self, 28 | query: Vec, 29 | n: usize, 30 | k: usize, 31 | num_samples: usize, 32 | ) -> Result>>; 33 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec; 34 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>; 35 | fn estimate_deltas(&mut self, n: usize); 36 | } 37 | 38 | #[pymethods] 39 | impl ShardedMemmapIndex { 40 | #[new] 41 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1))] 42 | pub fn new(_py: Python, paths: Vec<(String, String)>, vocab: usize) -> PyResult { 43 | if vocab <= u16::MAX as usize + 1 { 44 | Ok(ShardedMemmapIndex { 45 | index: Box::new(ShardedMemmapIndexRs::::new(paths, vocab)?), 46 | }) 47 | } else { 48 | Ok(ShardedMemmapIndex { 49 | index: Box::new(ShardedMemmapIndexRs::::new(paths, vocab)?), 50 | }) 51 | } 52 | } 53 | 54 | #[staticmethod] 55 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1, verbose=false))] 56 | pub fn build(paths: Vec<(String, String)>, vocab: usize, verbose: bool) -> PyResult { 57 | if vocab <= u16::MAX as usize + 1 { 58 | Ok(ShardedMemmapIndex { 59 | index: Box::new(ShardedMemmapIndexRs::::build(paths, vocab, verbose)?), 60 | }) 61 | } else { 62 | Ok(ShardedMemmapIndex { 63 | index: Box::new(ShardedMemmapIndexRs::::build(paths, vocab, verbose)?), 64 | }) 65 | } 66 | } 67 | 68 | pub fn is_sorted(&self) -> bool { 69 | self.index.is_sorted() 70 | } 71 | 72 | pub fn contains(&self, query: Vec) -> bool { 73 | self.index.contains(query) 74 | } 75 | 76 | pub fn count(&self, query: Vec) -> usize { 77 | self.index.count(query) 78 | } 79 | 80 | pub fn count_next(&self, query: Vec) -> Vec { 81 | self.index.count_next(query) 82 | } 83 | 84 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> { 85 | self.index.batch_count_next(queries) 86 | } 87 | 88 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 89 | pub fn sample_unsmoothed( 90 | &self, 91 | query: Vec, 92 | n: usize, 93 | k: usize, 94 | num_samples: usize, 95 | ) -> Result>> { 96 | self.index.sample_unsmoothed(query, n, k, num_samples) 97 | } 98 | 99 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 100 | /// tokens in the query. 101 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 102 | self.index.get_smoothed_probs(query) 103 | } 104 | 105 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 106 | /// tokens in the query. 107 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 108 | self.index.batch_get_smoothed_probs(queries) 109 | } 110 | 111 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 112 | pub fn sample_smoothed( 113 | &mut self, 114 | query: Vec, 115 | n: usize, 116 | k: usize, 117 | num_samples: usize, 118 | ) -> Result>> { 119 | self.index.sample_smoothed(query, n, k, num_samples) 120 | } 121 | 122 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 
123 | /// Improve smoothed model quality by replacing the default delta hyperparameters 124 | /// for models of order n and below with improved estimates over the entire index. 125 | /// , page 16. 126 | pub fn estimate_deltas(&mut self, n: usize) { 127 | self.index.estimate_deltas(n); 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/in_memory_index.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use funty::Unsigned; 3 | use rayon::prelude::*; 4 | use std::collections::HashMap; 5 | use std::fmt::Debug; 6 | use std::fs::{File, OpenOptions}; 7 | use std::io::Read; 8 | 9 | use crate::bindings::in_memory_index::InMemoryIndexTrait; 10 | use crate::mmap_slice::MmapSliceMut; 11 | use crate::sample::{KneserNeyCache, Sample}; 12 | use crate::table::SuffixTable; 13 | use crate::util::transmute_slice; 14 | 15 | /// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory. 16 | pub struct InMemoryIndexRs { 17 | table: SuffixTable, Box<[u64]>>, 18 | cache: KneserNeyCache, 19 | } 20 | 21 | impl InMemoryIndexRs { 22 | pub fn new(tokens: Vec, vocab: Option, verbose: bool) -> Self { 23 | let vocab = vocab.unwrap_or(u16::MAX as usize + 1); 24 | 25 | let table = SuffixTable::new(tokens, Some(vocab), verbose); 26 | debug_assert!(table.is_sorted()); 27 | 28 | InMemoryIndexRs { 29 | table, 30 | cache: KneserNeyCache::default(), 31 | } 32 | } 33 | 34 | pub fn from_token_file( 35 | path: String, 36 | token_limit: Option, 37 | vocab: usize, 38 | verbose: bool, 39 | ) -> Result { 40 | let mut buffer = Vec::new(); 41 | let mut file = File::open(&path)?; 42 | 43 | if let Some(max_tokens) = token_limit { 44 | // Limit on the number of tokens to consider is provided 45 | let max_bytes = max_tokens * std::mem::size_of::(); 46 | file.take(max_bytes as u64).read_to_end(&mut buffer)?; 47 | } else { 48 | file.read_to_end(&mut buffer)?; 49 | }; 50 | 51 | let tokens = transmute_slice::(buffer.as_slice()); 52 | let table = SuffixTable::new(tokens, Some(vocab), verbose); 53 | debug_assert!(table.is_sorted()); 54 | 55 | Ok(InMemoryIndexRs { 56 | table, 57 | cache: KneserNeyCache::default(), 58 | }) 59 | } 60 | 61 | fn read_file_to_boxed_slice(path: &str) -> Result> { 62 | let mut file = File::open(path)?; 63 | let file_len_bytes = file.metadata()?.len() as usize; 64 | 65 | // Ensure file size is a multiple of size of E 66 | if file_len_bytes % std::mem::size_of::() != 0 { 67 | anyhow::bail!("File size is not a multiple of element size"); 68 | } 69 | 70 | let num_elements = file_len_bytes / std::mem::size_of::(); 71 | let mut vec: Vec = Vec::with_capacity(num_elements); 72 | unsafe { 73 | let buf = std::slice::from_raw_parts_mut(vec.as_mut_ptr() as *mut u8, file_len_bytes); 74 | file.read_exact(buf)?; 75 | vec.set_len(num_elements); 76 | } 77 | 78 | Ok(vec.into_boxed_slice()) 79 | } 80 | 81 | pub fn from_disk(text_path: String, table_path: String, vocab: usize) -> Result { 82 | let text = Self::read_file_to_boxed_slice::(&text_path)?; 83 | let table = Self::read_file_to_boxed_slice::(&table_path)?; 84 | 85 | let suffix_table = SuffixTable::from_parts(text, table, Some(vocab)); 86 | debug_assert!(suffix_table.is_sorted()); 87 | 88 | Ok(InMemoryIndexRs { 89 | table: suffix_table, 90 | cache: KneserNeyCache::default(), 91 | }) 92 | } 93 | 94 | pub fn save_text(&self, path: String) -> Result<()> { 95 | let text = self.table.get_text(); 96 | let file = 
OpenOptions::new() 97 | .create(true) 98 | .read(true) 99 | .write(true) 100 | .open(&path)?; 101 | 102 | let file_len = text.len() * std::mem::size_of::(); 103 | file.set_len(file_len as u64)?; 104 | 105 | let mut mmap = MmapSliceMut::::new(&file)?; 106 | mmap.copy_from_slice(text); 107 | mmap.flush()?; 108 | 109 | Ok(()) 110 | } 111 | 112 | pub fn save_table(&self, path: String) -> Result<()> { 113 | let table = self.table.get_table(); 114 | let file = OpenOptions::new() 115 | .create(true) 116 | .read(true) 117 | .write(true) 118 | .open(&path)?; 119 | 120 | file.set_len((table.len() * 8) as u64)?; 121 | 122 | let mut mmap = MmapSliceMut::::new(&file)?; 123 | mmap.copy_from_slice(table); 124 | mmap.flush()?; 125 | 126 | Ok(()) 127 | } 128 | } 129 | 130 | impl Sample for InMemoryIndexRs { 131 | fn get_cache(&self) -> &KneserNeyCache { 132 | &self.cache 133 | } 134 | 135 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache { 136 | &mut self.cache 137 | } 138 | 139 | fn count_next_slice(&self, query: &[T]) -> Vec { 140 | self.table.count_next(query) 141 | } 142 | 143 | fn count_ngrams(&self, n: usize) -> HashMap { 144 | self.table.count_ngrams(n) 145 | } 146 | } 147 | 148 | impl InMemoryIndexTrait for InMemoryIndexRs { 149 | fn save_table(&self, table_path: String) -> Result<()> { 150 | self.save_table(table_path) 151 | } 152 | 153 | fn save_text(&self, text_path: String) -> Result<()> { 154 | self.save_text(text_path) 155 | } 156 | 157 | fn is_sorted(&self) -> bool { 158 | self.table.is_sorted() 159 | } 160 | 161 | fn contains(&self, query: Vec) -> bool { 162 | let query: Vec = query 163 | .iter() 164 | .filter_map(|&item| T::try_from(item).ok()) 165 | .collect(); 166 | self.table.contains(&query) 167 | } 168 | 169 | fn positions(&self, query: Vec) -> Vec { 170 | let query: Vec = query 171 | .iter() 172 | .filter_map(|&item| T::try_from(item).ok()) 173 | .collect(); 174 | self.table.positions(&query).to_vec() 175 | } 176 | 177 | fn count(&self, query: Vec) -> usize { 178 | let query: Vec = query 179 | .iter() 180 | .filter_map(|&item| T::try_from(item).ok()) 181 | .collect(); 182 | self.table.positions(&query).len() 183 | } 184 | 185 | fn count_next(&self, query: Vec) -> Vec { 186 | let query: Vec = query 187 | .iter() 188 | .filter_map(|&item| T::try_from(item).ok()) 189 | .collect(); 190 | self.table.count_next(&query) 191 | } 192 | 193 | fn batch_count_next(&self, queries: Vec>) -> Vec> { 194 | queries 195 | .into_par_iter() 196 | .map(|query| self.count_next(query)) 197 | .collect() 198 | } 199 | 200 | fn sample_smoothed( 201 | &mut self, 202 | query: Vec, 203 | n: usize, 204 | k: usize, 205 | num_samples: usize, 206 | ) -> Result>> { 207 | let query: Vec = query 208 | .iter() 209 | .filter_map(|&item| T::try_from(item).ok()) 210 | .collect(); 211 | 212 | let samples_batch = >::sample_smoothed(self, &query, n, k, num_samples)?; 213 | Ok(samples_batch 214 | .into_iter() 215 | .map(|samples| { 216 | samples 217 | .into_iter() 218 | .filter_map(|sample| { 219 | match TryInto::::try_into(sample) { 220 | Ok(value) => Some(value), 221 | Err(_) => None, // Silently skip values that can't be converted 222 | } 223 | }) 224 | .collect::>() 225 | }) 226 | .collect()) 227 | } 228 | 229 | fn sample_unsmoothed( 230 | &self, 231 | query: Vec, 232 | n: usize, 233 | k: usize, 234 | num_samples: usize, 235 | ) -> Result>> { 236 | let query: Vec = query 237 | .iter() 238 | .filter_map(|&item| T::try_from(item).ok()) 239 | .collect(); 240 | 241 | let samples_batch = 242 | >::sample_unsmoothed(self, 
&query, n, k, num_samples)?; 243 | Ok(samples_batch 244 | .into_iter() 245 | .map(|samples| { 246 | samples 247 | .into_iter() 248 | .filter_map(|sample| { 249 | match TryInto::::try_into(sample) { 250 | Ok(value) => Some(value), 251 | Err(_) => None, // Silently skip values that can't be converted 252 | } 253 | }) 254 | .collect::>() 255 | }) 256 | .collect()) 257 | } 258 | 259 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 260 | let query: Vec = query 261 | .iter() 262 | .filter_map(|&item| T::try_from(item).ok()) 263 | .collect(); 264 | 265 | >::get_smoothed_probs(self, &query) 266 | } 267 | 268 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 269 | let queries: Vec> = queries 270 | .into_iter() 271 | .map(|query| { 272 | query 273 | .iter() 274 | .filter_map(|&item| T::try_from(item).ok()) 275 | .collect() 276 | }) 277 | .collect(); 278 | >::batch_get_smoothed_probs(self, &queries) 279 | } 280 | 281 | fn estimate_deltas(&mut self, n: usize) { 282 | >::estimate_deltas(self, n) 283 | } 284 | } 285 | 286 | 287 | #[cfg(test)] 288 | pub mod tests { 289 | use super::*; 290 | use utf16_literal::utf16; 291 | use crate::table::SuffixTable; 292 | 293 | fn sais(text: &str) -> SuffixTable { 294 | SuffixTable::new(text.encode_utf16().collect::>(), None, false) 295 | } 296 | 297 | fn utf16_as_usize(s: &str) -> Vec { 298 | s.encode_utf16().map(|x| x as usize).collect() 299 | } 300 | 301 | #[test] 302 | fn sample_unsmoothed_empty_query_exists() { 303 | let s = utf16!("aaa"); 304 | let index: Box> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false)); 305 | 306 | let seqs = index.sample_unsmoothed(&[], 3, 10, 1).unwrap(); 307 | 308 | assert_eq!(*seqs[0].last().unwrap(), s[0]); 309 | } 310 | 311 | #[test] 312 | fn sample_unsmoothed_u16_exists() { 313 | let s = utf16!("aaaa"); 314 | let a = &s[0..1]; 315 | let index: Box> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false)); 316 | 317 | let seqs = index.sample_unsmoothed(a, 3, 10, 1).unwrap(); 318 | 319 | assert_eq!(*seqs[0].last().unwrap(), a[0]); 320 | } 321 | 322 | #[test] 323 | fn sample_unsmoothed_u32_exists() { 324 | let s: Vec = "aaaa".encode_utf16().map(|c| c as u32).collect(); 325 | let u32_vocab = Some(u16::MAX as usize + 2); 326 | let index: Box> = Box::new(InMemoryIndexRs::::new(s.clone(), u32_vocab, false)); 327 | 328 | let seqs = index.sample_unsmoothed(&s[0..1], 3, 10, 1).unwrap(); 329 | 330 | assert_eq!(*seqs[0].last().unwrap(), s[0]); 331 | } 332 | 333 | #[test] 334 | fn sample_unsmoothed_usize_exists() { 335 | let s = utf16_as_usize("aaaa"); 336 | let index: Box = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false)); 337 | 338 | let seqs = index.sample_unsmoothed(s[0..1].to_vec(), 3, 10, 1).unwrap(); 339 | 340 | assert_eq!(*seqs[0].last().unwrap(), s[0]); 341 | } 342 | 343 | #[test] 344 | fn sample_smoothed_exists() { 345 | let s = utf16!("aabbccabccba"); 346 | let mut index: Box> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false)); 347 | 348 | let tokens = &index.sample_smoothed(&s[0..1], 3, 10, 1).unwrap()[0]; 349 | 350 | assert_eq!(tokens.len(), 11); 351 | } 352 | 353 | #[test] 354 | fn sample_smoothed_empty_query_exists() { 355 | let s: Vec = "aabbccabccba".encode_utf16().collect(); 356 | let mut index: Box> = Box::new(InMemoryIndexRs::new(s, None, false)); 357 | 358 | let tokens = &index.sample_smoothed(&[], 1, 10, 10).unwrap()[0]; 359 | 360 | assert_eq!(tokens.len(), 10); 361 | } 362 | 363 | #[test] 364 | fn smoothed_probs_exists() { 365 | let tokens = "aaaaaaaabc".to_string(); 366 | let 
tokens_vec: Vec = tokens.encode_utf16().collect(); 367 | let query: Vec<_> = vec![utf16!("b")[0]]; 368 | 369 | // Get unsmoothed probs for query 370 | let sa: SuffixTable = sais(&tokens); 371 | let bigram_counts = sa.count_next(&query); 372 | let unsmoothed_probs = bigram_counts 373 | .iter() 374 | .map(|&x| x as f64 / bigram_counts.iter().sum::() as f64) 375 | .collect::>(); 376 | 377 | // Get smoothed probs for query 378 | let mut index: Box> = Box::new(InMemoryIndexRs::new(tokens_vec, None, false)); 379 | let smoothed_probs = index.get_smoothed_probs(&query); 380 | 381 | // Compare unsmoothed and smoothed probabilities 382 | let a = utf16!("a")[0] as usize; 383 | let c = utf16!("c")[0] as usize; 384 | 385 | // The naive bigram probability for query 'b' is p(c) = 1.0. 386 | assert!(unsmoothed_probs[a] == 0.0); 387 | assert!(unsmoothed_probs[c] == 1.0); 388 | 389 | // The smoothed bigram probabilities interpolate with the lower-order unigram 390 | // probabilities where p(a) is high, lowering p(c) 391 | assert!(smoothed_probs[a] > 0.1); 392 | assert!(smoothed_probs[c] < 1.0); 393 | } 394 | } 395 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod mmap_slice; 2 | pub use bindings::in_memory_index::InMemoryIndex; 3 | pub use bindings::memmap_index::MemmapIndex; 4 | pub use bindings::sharded_memmap_index::ShardedMemmapIndex; 5 | pub use bindings::sharded_in_memory_index::ShardedInMemoryIndex; 6 | pub use sharded_in_memory_index::ShardedInMemoryIndexRs; 7 | pub use in_memory_index::InMemoryIndexRs; 8 | pub use sample::Sample; 9 | 10 | pub use table::SuffixTable; 11 | 12 | /// Python bindings 13 | #[cfg(not(feature = "rust"))] 14 | use pyo3::prelude::*; 15 | 16 | mod bindings; 17 | mod in_memory_index; 18 | mod memmap_index; 19 | mod par_quicksort; 20 | mod sample; 21 | mod sharded_memmap_index; 22 | mod sharded_in_memory_index; 23 | mod table; 24 | mod util; 25 | 26 | #[cfg(not(feature = "rust"))] 27 | #[pymodule] 28 | fn tokengrams(m: &Bound<'_, PyModule>) -> PyResult<()> { 29 | m.add_class::()?; 30 | m.add_class::()?; 31 | m.add_class::()?; 32 | m.add_class::()?; 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /src/memmap_index.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use funty::Unsigned; 3 | use rayon::prelude::*; 4 | use std::collections::HashMap; 5 | use std::fs::{File, OpenOptions}; 6 | use std::time::Instant; 7 | 8 | use crate::bindings::memmap_index::MemmapIndexTrait; 9 | use crate::mmap_slice::{MmapSlice, MmapSliceMut}; 10 | use crate::par_quicksort::par_sort_unstable_by_key; 11 | use crate::sample::{KneserNeyCache, Sample}; 12 | use crate::table::SuffixTable; 13 | 14 | /// A memmap index exposes suffix table functionality over text corpora too large to fit in memory. 
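// A minimal usage sketch at the Rust level (file names, vocab size, type parameter, and query
// token ids below are chosen for illustration only; the query methods come from the
// `MemmapIndexTrait` impl further down in this file):
//
//     let index = MemmapIndexRs::<u16>::build(
//         "corpus_tokens.bin".to_string(),
//         "suffix_table.bin".to_string(),
//         65_536, // vocab
//         false,  // verbose
//     )?;
//     let next_token_counts = index.count_next(vec![464, 2159]); // one count per vocab entry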
15 | pub struct MemmapIndexRs { 16 | table: SuffixTable, MmapSlice>, 17 | cache: KneserNeyCache, 18 | } 19 | 20 | impl MemmapIndexRs { 21 | pub fn new(text_path: String, table_path: String, vocab: usize) -> Result { 22 | let text_file = File::open(&text_path)?; 23 | let table_file = File::open(&table_path)?; 24 | 25 | let table = SuffixTable::from_parts( 26 | MmapSlice::::new(&text_file)?, 27 | MmapSlice::new(&table_file)?, 28 | Some(vocab), 29 | ); 30 | assert!(table.is_sorted()); 31 | 32 | Ok(MemmapIndexRs { 33 | table, 34 | cache: KneserNeyCache::default(), 35 | }) 36 | } 37 | 38 | pub fn build( 39 | text_path: String, 40 | table_path: String, 41 | vocab: usize, 42 | verbose: bool, 43 | ) -> Result { 44 | // Memory map the text as read-only 45 | let text_mmap = MmapSlice::::new(&File::open(&text_path).unwrap()).unwrap(); 46 | 47 | let table_file = OpenOptions::new() 48 | .create(true) 49 | .read(true) 50 | .write(true) 51 | .open(&table_path)?; 52 | 53 | // Allocate space on disk for the table 54 | let table_size = text_mmap.len() * 8; 55 | table_file.set_len(table_size as u64)?; 56 | 57 | if verbose { 58 | println!("Writing indices to disk..."); 59 | } 60 | let start = Instant::now(); 61 | let mut table_mmap = MmapSliceMut::::new(&table_file)?; 62 | table_mmap 63 | .iter_mut() 64 | .enumerate() 65 | .for_each(|(i, x)| *x = i as u64); 66 | 67 | assert_eq!(table_mmap.len(), text_mmap.len()); 68 | if verbose { 69 | println!("Time elapsed: {:?}", start.elapsed()); 70 | } 71 | let start = Instant::now(); 72 | 73 | // TODO: Be even smarter about this? We may need to take into account the number of CPUs 74 | // available as well. These magic numbers were tuned on a server with 48 physical cores. 75 | // Empirically we start getting stack overflows between 5B and 10B tokens when using the 76 | // default stack size of 2MB. We scale the stack size as log2(n) * 8MB to avoid this. 77 | let scale = (text_mmap.len() as f64) / 5e9; // 5B tokens 78 | let stack_size = scale.log2().max(1.0) * 8e6; // 8MB 79 | 80 | rayon::ThreadPoolBuilder::new() 81 | .stack_size(stack_size as usize) 82 | .build() 83 | .unwrap() 84 | .install(|| { 85 | // Sort the indices by the suffixes they point to. 86 | // The unstable algorithm is critical for avoiding out-of-memory errors, since it does 87 | // not allocate any more memory than the input and output slices. 
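// For a sense of scale of the stack-size heuristic above (numbers illustrative, not from the
// original source): a 20B-token corpus gives scale = 20e9 / 5e9 = 4.0, so each worker thread
// gets log2(4.0) * 8 MB = 16 MB of stack, while corpora of 5B tokens or fewer stay at the
// 8 MB floor enforced by `.max(1.0)`.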
88 | println!("Sorting indices..."); 89 | par_sort_unstable_by_key( 90 | table_mmap.as_slice_mut(), 91 | |&i| &text_mmap[i as usize..], 92 | verbose, 93 | ); 94 | }); 95 | if verbose { 96 | println!("Time elapsed: {:?}", start.elapsed()); 97 | } 98 | 99 | // Re-open the table as read-only 100 | let table_mmap = MmapSlice::new(&table_file)?; 101 | let table = SuffixTable::from_parts(text_mmap, table_mmap, Some(vocab)); 102 | debug_assert!(table.is_sorted()); 103 | 104 | Ok(MemmapIndexRs { 105 | table, 106 | cache: KneserNeyCache::default(), 107 | }) 108 | } 109 | } 110 | 111 | impl Sample for MemmapIndexRs { 112 | fn get_cache(&self) -> &KneserNeyCache { 113 | &self.cache 114 | } 115 | 116 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache { 117 | &mut self.cache 118 | } 119 | 120 | fn count_next_slice(&self, query: &[T]) -> Vec { 121 | self.table.count_next(query) 122 | } 123 | 124 | fn count_ngrams(&self, n: usize) -> HashMap { 125 | self.table.count_ngrams(n) 126 | } 127 | } 128 | 129 | impl MemmapIndexTrait for MemmapIndexRs 130 | where 131 | T: Unsigned, 132 | { 133 | fn positions(&self, query: Vec) -> Vec { 134 | let query: Vec = query 135 | .iter() 136 | .filter_map(|&item| T::try_from(item).ok()) 137 | .collect(); 138 | self.table.positions(&query).to_vec() 139 | } 140 | 141 | fn is_sorted(&self) -> bool { 142 | self.table.is_sorted() 143 | } 144 | 145 | fn contains(&self, query: Vec) -> bool { 146 | let query: Vec = query 147 | .iter() 148 | .filter_map(|&item| T::try_from(item).ok()) 149 | .collect(); 150 | self.table.contains(&query) 151 | } 152 | 153 | fn count(&self, query: Vec) -> usize { 154 | let query: Vec = query 155 | .iter() 156 | .filter_map(|&item| T::try_from(item).ok()) 157 | .collect(); 158 | self.table.positions(&query).len() 159 | } 160 | 161 | fn count_next(&self, query: Vec) -> Vec { 162 | let query: Vec = query 163 | .iter() 164 | .filter_map(|&item| T::try_from(item).ok()) 165 | .collect(); 166 | self.table.count_next(&query) 167 | } 168 | 169 | fn batch_count_next(&self, queries: Vec>) -> Vec> { 170 | queries 171 | .into_par_iter() 172 | .map(|query| self.count_next(query)) 173 | .collect() 174 | } 175 | 176 | fn sample_smoothed( 177 | &mut self, 178 | query: Vec, 179 | n: usize, 180 | k: usize, 181 | num_samples: usize, 182 | ) -> Result>> { 183 | let query: Vec = query 184 | .iter() 185 | .filter_map(|&item| T::try_from(item).ok()) 186 | .collect(); 187 | 188 | let samples_batch = >::sample_smoothed(self, &query, n, k, num_samples)?; 189 | Ok(samples_batch 190 | .into_iter() 191 | .map(|samples| { 192 | samples 193 | .into_iter() 194 | .filter_map(|sample| { 195 | match TryInto::::try_into(sample) { 196 | Ok(value) => Some(value), 197 | Err(_) => None, // Silently skip values that can't be converted 198 | } 199 | }) 200 | .collect::>() 201 | }) 202 | .collect()) 203 | } 204 | 205 | fn sample_unsmoothed( 206 | &self, 207 | query: Vec, 208 | n: usize, 209 | k: usize, 210 | num_samples: usize, 211 | ) -> Result>> { 212 | let query: Vec = query 213 | .iter() 214 | .filter_map(|&item| T::try_from(item).ok()) 215 | .collect(); 216 | 217 | let samples_batch = 218 | >::sample_unsmoothed(self, &query, n, k, num_samples)?; 219 | Ok(samples_batch 220 | .into_iter() 221 | .map(|samples| { 222 | samples 223 | .into_iter() 224 | .filter_map(|sample| { 225 | match TryInto::::try_into(sample) { 226 | Ok(value) => Some(value), 227 | Err(_) => None, // Silently skip values that can't be converted 228 | } 229 | }) 230 | .collect::>() 231 | }) 232 | .collect()) 233 | } 234 | 
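// The conversions in this impl all follow the same pattern: queries arrive as `usize` token ids
// from the Python layer and are narrowed to the index's element type, silently dropping any id
// that does not fit. A minimal sketch of that narrowing (helper name hypothetical):
//
//     fn narrow_query<T: TryFrom<usize>>(query: &[usize]) -> Vec<T> {
//         query.iter().filter_map(|&id| T::try_from(id).ok()).collect()
//     }
//
//     // With T = u16, an out-of-range id is discarded rather than raising an error:
//     // narrow_query::<u16>(&[65_535, 70_000]) == vec![65_535]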
235 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 236 | let query: Vec = query 237 | .iter() 238 | .filter_map(|&item| T::try_from(item).ok()) 239 | .collect(); 240 | 241 | >::get_smoothed_probs(self, &query) 242 | } 243 | 244 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 245 | let queries: Vec> = queries 246 | .into_iter() 247 | .map(|query| { 248 | query 249 | .iter() 250 | .filter_map(|&item| T::try_from(item).ok()) 251 | .collect() 252 | }) 253 | .collect(); 254 | >::batch_get_smoothed_probs(self, &queries) 255 | } 256 | 257 | fn estimate_deltas(&mut self, n: usize) { 258 | >::estimate_deltas(self, n) 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /src/mmap_slice.rs: -------------------------------------------------------------------------------- 1 | use funty::Unsigned; 2 | use memmap2::{Mmap, MmapAsRawDesc, MmapMut}; 3 | use std::marker::PhantomData; 4 | use std::ops::{Deref, DerefMut}; 5 | 6 | /// An immutable memory-mapped slice of unsigned integers 7 | pub struct MmapSlice { 8 | mmap: Mmap, 9 | _element_type: PhantomData, 10 | } 11 | 12 | impl MmapSlice { 13 | pub fn new(file: F) -> std::io::Result { 14 | let raw = unsafe { Mmap::map(file)? }; 15 | 16 | // Sanity check that the file size is a multiple of the element size. 17 | if raw.len() % std::mem::size_of::() != 0 { 18 | Err(std::io::Error::new( 19 | std::io::ErrorKind::InvalidData, 20 | "File size is not a multiple of element size", 21 | )) 22 | } else { 23 | Ok(MmapSlice { 24 | mmap: raw, 25 | _element_type: PhantomData, 26 | }) 27 | } 28 | } 29 | 30 | // Return the number of items of type T that can fit in the memory map. 31 | pub fn len(&self) -> usize { 32 | self.mmap.len() / std::mem::size_of::() 33 | } 34 | 35 | pub fn as_slice<'a>(&'a self) -> &'a [T] { 36 | unsafe { std::slice::from_raw_parts(self.mmap.as_ptr() as *const T, self.len()) } 37 | } 38 | } 39 | 40 | impl Deref for MmapSlice { 41 | type Target = [T]; 42 | 43 | fn deref(&self) -> &[T] { 44 | self.as_slice() 45 | } 46 | } 47 | 48 | /// A mutable memory-mapped slice of unsigned integers 49 | pub struct MmapSliceMut { 50 | mmap: MmapMut, 51 | _element_type: PhantomData, 52 | } 53 | 54 | impl MmapSliceMut { 55 | pub fn new(file: F) -> std::io::Result { 56 | let raw = unsafe { MmapMut::map_mut(file)? 
}; 57 | 58 | Ok(MmapSliceMut { 59 | mmap: raw, 60 | _element_type: PhantomData, 61 | }) 62 | } 63 | 64 | pub fn len(&self) -> usize { 65 | self.mmap.len() / std::mem::size_of::() 66 | } 67 | 68 | pub fn as_slice<'a>(&'a self) -> &'a [T] { 69 | unsafe { std::slice::from_raw_parts(self.mmap.as_ptr() as *const T, self.len()) } 70 | } 71 | 72 | pub fn as_slice_mut<'a>(&'a mut self) -> &'a mut [T] { 73 | unsafe { std::slice::from_raw_parts_mut(self.mmap.as_mut_ptr() as *mut T, self.len()) } 74 | } 75 | 76 | pub fn into_read_only(self) -> std::io::Result> { 77 | Ok(MmapSlice { 78 | mmap: self.mmap.make_read_only()?, 79 | _element_type: PhantomData, 80 | }) 81 | } 82 | 83 | pub fn flush(&self) -> std::io::Result<()> { 84 | self.mmap.flush() 85 | } 86 | } 87 | 88 | impl Deref for MmapSliceMut { 89 | type Target = [T]; 90 | 91 | fn deref(&self) -> &[T] { 92 | self.as_slice() 93 | } 94 | } 95 | 96 | impl DerefMut for MmapSliceMut { 97 | fn deref_mut(&mut self) -> &mut [T] { 98 | self.as_slice_mut() 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/par_quicksort.rs: -------------------------------------------------------------------------------- 1 | //! Parallel quicksort. 2 | //! 3 | //! This implementation is copied verbatim from rayon's `std::slice::sort_unstable` and then parallelized. 4 | //! The only difference from the original is that calls to `recurse` are executed in parallel using 5 | //! `rayon_core::join`. 6 | 7 | use indicatif::{ProgressBar, ProgressStyle}; 8 | use std::cmp; 9 | use std::marker::PhantomData; 10 | use std::mem::{self, MaybeUninit}; 11 | use std::ptr; 12 | 13 | pub fn par_sort_unstable_by_key(data: &mut [T], f: F, verbose: bool) 14 | where 15 | T: Send, 16 | K: Ord + Send, 17 | F: Fn(&T) -> K + Sync, 18 | { 19 | par_quicksort(data, |a, b| f(a).lt(&f(b)), verbose); 20 | } 21 | 22 | /// When dropped, copies from `src` into `dest`. 23 | #[must_use] 24 | struct CopyOnDrop<'a, T> { 25 | src: *const T, 26 | dest: *mut T, 27 | /// `src` is often a local pointer here, make sure we have appropriate 28 | /// PhantomData so that dropck can protect us. 29 | marker: PhantomData<&'a mut T>, 30 | } 31 | 32 | impl<'a, T> CopyOnDrop<'a, T> { 33 | /// Construct from a source pointer and a destination 34 | /// Assumes dest lives longer than src, since there is no easy way to 35 | /// copy down lifetime information from another pointer 36 | unsafe fn new(src: &'a T, dest: *mut T) -> Self { 37 | CopyOnDrop { 38 | src, 39 | dest, 40 | marker: PhantomData, 41 | } 42 | } 43 | } 44 | 45 | impl Drop for CopyOnDrop<'_, T> { 46 | fn drop(&mut self) { 47 | // SAFETY: This is a helper class. 48 | // Please refer to its usage for correctness. 49 | // Namely, one must be sure that `src` and `dst` does not overlap as required by `ptr::copy_nonoverlapping`. 50 | unsafe { 51 | ptr::copy_nonoverlapping(self.src, self.dest, 1); 52 | } 53 | } 54 | } 55 | 56 | /// Shifts the first element to the right until it encounters a greater or equal element. 57 | fn shift_head(v: &mut [T], is_less: &F) 58 | where 59 | F: Fn(&T, &T) -> bool, 60 | { 61 | let len = v.len(); 62 | // SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a 63 | // pointer) and copying memory (`ptr::copy_nonoverlapping`). 64 | // 65 | // a. Indexing: 66 | // 1. We checked the size of the array to >=2. 67 | // 2. All the indexing that we will do is always between {0 <= index < len} at most. 68 | // 69 | // b. Memory copying 70 | // 1. 
We are obtaining pointers to references which are guaranteed to be valid. 71 | // 2. They cannot overlap because we obtain pointers to difference indices of the slice. 72 | // Namely, `i` and `i-1`. 73 | // 3. If the slice is properly aligned, the elements are properly aligned. 74 | // It is the caller's responsibility to make sure the slice is properly aligned. 75 | // 76 | // See comments below for further detail. 77 | unsafe { 78 | // If the first two elements are out-of-order... 79 | if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) { 80 | // Read the first element into a stack-allocated variable. If a following comparison 81 | // operation panics, `hole` will get dropped and automatically write the element back 82 | // into the slice. 83 | let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0))); 84 | let v = v.as_mut_ptr(); 85 | let mut hole = CopyOnDrop::new(&*tmp, v.add(1)); 86 | ptr::copy_nonoverlapping(v.add(1), v.add(0), 1); 87 | 88 | for i in 2..len { 89 | if !is_less(&*v.add(i), &*tmp) { 90 | break; 91 | } 92 | 93 | // Move `i`-th element one place to the left, thus shifting the hole to the right. 94 | ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1); 95 | hole.dest = v.add(i); 96 | } 97 | // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. 98 | } 99 | } 100 | } 101 | 102 | /// Shifts the last element to the left until it encounters a smaller or equal element. 103 | fn shift_tail(v: &mut [T], is_less: &F) 104 | where 105 | F: Fn(&T, &T) -> bool, 106 | { 107 | let len = v.len(); 108 | // SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a 109 | // pointer) and copying memory (`ptr::copy_nonoverlapping`). 110 | // 111 | // a. Indexing: 112 | // 1. We checked the size of the array to >= 2. 113 | // 2. All the indexing that we will do is always between `0 <= index < len-1` at most. 114 | // 115 | // b. Memory copying 116 | // 1. We are obtaining pointers to references which are guaranteed to be valid. 117 | // 2. They cannot overlap because we obtain pointers to difference indices of the slice. 118 | // Namely, `i` and `i+1`. 119 | // 3. If the slice is properly aligned, the elements are properly aligned. 120 | // It is the caller's responsibility to make sure the slice is properly aligned. 121 | // 122 | // See comments below for further detail. 123 | unsafe { 124 | // If the last two elements are out-of-order... 125 | if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) { 126 | // Read the last element into a stack-allocated variable. If a following comparison 127 | // operation panics, `hole` will get dropped and automatically write the element back 128 | // into the slice. 129 | let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1))); 130 | let v = v.as_mut_ptr(); 131 | let mut hole = CopyOnDrop::new(&*tmp, v.add(len - 2)); 132 | ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1); 133 | 134 | for i in (0..len - 2).rev() { 135 | if !is_less(&*tmp, &*v.add(i)) { 136 | break; 137 | } 138 | 139 | // Move `i`-th element one place to the right, thus shifting the hole to the left. 140 | ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1); 141 | hole.dest = v.add(i); 142 | } 143 | // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. 144 | } 145 | } 146 | } 147 | 148 | /// Partially sorts a slice by shifting several out-of-order elements around. 149 | /// 150 | /// Returns `true` if the slice is sorted at the end. 
This function is *O*(*n*) worst-case. 151 | #[cold] 152 | fn partial_insertion_sort(v: &mut [T], is_less: &F) -> bool 153 | where 154 | F: Fn(&T, &T) -> bool, 155 | { 156 | // Maximum number of adjacent out-of-order pairs that will get shifted. 157 | const MAX_STEPS: usize = 5; 158 | // If the slice is shorter than this, don't shift any elements. 159 | const SHORTEST_SHIFTING: usize = 50; 160 | 161 | let len = v.len(); 162 | let mut i = 1; 163 | 164 | for _ in 0..MAX_STEPS { 165 | // SAFETY: We already explicitly did the bound checking with `i < len`. 166 | // All our subsequent indexing is only in the range `0 <= index < len` 167 | unsafe { 168 | // Find the next pair of adjacent out-of-order elements. 169 | while i < len && !is_less(v.get_unchecked(i), v.get_unchecked(i - 1)) { 170 | i += 1; 171 | } 172 | } 173 | 174 | // Are we done? 175 | if i == len { 176 | return true; 177 | } 178 | 179 | // Don't shift elements on short arrays, that has a performance cost. 180 | if len < SHORTEST_SHIFTING { 181 | return false; 182 | } 183 | 184 | // Swap the found pair of elements. This puts them in correct order. 185 | v.swap(i - 1, i); 186 | 187 | // Shift the smaller element to the left. 188 | shift_tail(&mut v[..i], is_less); 189 | // Shift the greater element to the right. 190 | shift_head(&mut v[i..], is_less); 191 | } 192 | 193 | // Didn't manage to sort the slice in the limited number of steps. 194 | false 195 | } 196 | 197 | /// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case. 198 | fn insertion_sort(v: &mut [T], is_less: &F) 199 | where 200 | F: Fn(&T, &T) -> bool, 201 | { 202 | for i in 1..v.len() { 203 | shift_tail(&mut v[..i + 1], is_less); 204 | } 205 | } 206 | 207 | /// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case. 208 | #[cold] 209 | fn heapsort(v: &mut [T], is_less: &F) 210 | where 211 | F: Fn(&T, &T) -> bool, 212 | { 213 | // This binary heap respects the invariant `parent >= child`. 214 | let sift_down = |v: &mut [T], mut node| { 215 | loop { 216 | // Children of `node`. 217 | let mut child = 2 * node + 1; 218 | if child >= v.len() { 219 | break; 220 | } 221 | 222 | // Choose the greater child. 223 | if child + 1 < v.len() && is_less(&v[child], &v[child + 1]) { 224 | child += 1; 225 | } 226 | 227 | // Stop if the invariant holds at `node`. 228 | if !is_less(&v[node], &v[child]) { 229 | break; 230 | } 231 | 232 | // Swap `node` with the greater child, move one step down, and continue sifting. 233 | v.swap(node, child); 234 | node = child; 235 | } 236 | }; 237 | 238 | // Build the heap in linear time. 239 | for i in (0..v.len() / 2).rev() { 240 | sift_down(v, i); 241 | } 242 | 243 | // Pop maximal elements from the heap. 244 | for i in (1..v.len()).rev() { 245 | v.swap(0, i); 246 | sift_down(&mut v[..i], 0); 247 | } 248 | } 249 | 250 | /// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal 251 | /// to `pivot`. 252 | /// 253 | /// Returns the number of elements smaller than `pivot`. 254 | /// 255 | /// Partitioning is performed block-by-block in order to minimize the cost of branching operations. 256 | /// This idea is presented in the [BlockQuicksort][pdf] paper. 257 | /// 258 | /// [pdf]: https://drops.dagstuhl.de/opus/volltexte/2016/6389/pdf/LIPIcs-ESA-2016-38.pdf 259 | fn partition_in_blocks(v: &mut [T], pivot: &T, is_less: &F) -> usize 260 | where 261 | F: Fn(&T, &T) -> bool, 262 | { 263 | // Number of elements in a typical block. 
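// Keeping BLOCK at 128 (below 256) means every within-block offset fits in a `u8`, so the two
// scratch arrays of candidate swap positions below stay at 128 bytes each and live entirely on
// the stack.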
264 | const BLOCK: usize = 128; 265 | 266 | // The partitioning algorithm repeats the following steps until completion: 267 | // 268 | // 1. Trace a block from the left side to identify elements greater than or equal to the pivot. 269 | // 2. Trace a block from the right side to identify elements smaller than the pivot. 270 | // 3. Exchange the identified elements between the left and right side. 271 | // 272 | // We keep the following variables for a block of elements: 273 | // 274 | // 1. `block` - Number of elements in the block. 275 | // 2. `start` - Start pointer into the `offsets` array. 276 | // 3. `end` - End pointer into the `offsets` array. 277 | // 4. `offsets - Indices of out-of-order elements within the block. 278 | 279 | // The current block on the left side (from `l` to `l.add(block_l)`). 280 | let mut l = v.as_mut_ptr(); 281 | let mut block_l = BLOCK; 282 | let mut start_l = ptr::null_mut(); 283 | let mut end_l = ptr::null_mut(); 284 | let mut offsets_l = [MaybeUninit::::uninit(); BLOCK]; 285 | 286 | // The current block on the right side (from `r.sub(block_r)` to `r`). 287 | // SAFETY: The documentation for .add() specifically mention that `vec.as_ptr().add(vec.len())` is always safe` 288 | let mut r = unsafe { l.add(v.len()) }; 289 | let mut block_r = BLOCK; 290 | let mut start_r = ptr::null_mut(); 291 | let mut end_r = ptr::null_mut(); 292 | let mut offsets_r = [MaybeUninit::::uninit(); BLOCK]; 293 | 294 | // FIXME: When we get VLAs, try creating one array of length `min(v.len(), 2 * BLOCK)` rather 295 | // than two fixed-size arrays of length `BLOCK`. VLAs might be more cache-efficient. 296 | 297 | // Returns the number of elements between pointers `l` (inclusive) and `r` (exclusive). 298 | fn width(l: *mut T, r: *mut T) -> usize { 299 | assert!(mem::size_of::() > 0); 300 | // FIXME: this should *likely* use `offset_from`, but more 301 | // investigation is needed (including running tests in miri). 302 | // TODO unstable: (r.addr() - l.addr()) / mem::size_of::() 303 | (r as usize - l as usize) / mem::size_of::() 304 | } 305 | 306 | loop { 307 | // We are done with partitioning block-by-block when `l` and `r` get very close. Then we do 308 | // some patch-up work in order to partition the remaining elements in between. 309 | let is_done = width(l, r) <= 2 * BLOCK; 310 | 311 | if is_done { 312 | // Number of remaining elements (still not compared to the pivot). 313 | let mut rem = width(l, r); 314 | if start_l < end_l || start_r < end_r { 315 | rem -= BLOCK; 316 | } 317 | 318 | // Adjust block sizes so that the left and right block don't overlap, but get perfectly 319 | // aligned to cover the whole remaining gap. 320 | if start_l < end_l { 321 | block_r = rem; 322 | } else if start_r < end_r { 323 | block_l = rem; 324 | } else { 325 | // There were the same number of elements to switch on both blocks during the last 326 | // iteration, so there are no remaining elements on either block. Cover the remaining 327 | // items with roughly equally-sized blocks. 328 | block_l = rem / 2; 329 | block_r = rem - block_l; 330 | } 331 | debug_assert!(block_l <= BLOCK && block_r <= BLOCK); 332 | debug_assert!(width(l, r) == block_l + block_r); 333 | } 334 | 335 | if start_l == end_l { 336 | // Trace `block_l` elements from the left side. 
337 | // TODO unstable: start_l = MaybeUninit::slice_as_mut_ptr(&mut offsets_l); 338 | start_l = offsets_l.as_mut_ptr() as *mut u8; 339 | end_l = start_l; 340 | let mut elem = l; 341 | 342 | for i in 0..block_l { 343 | // SAFETY: The unsafety operations below involve the usage of the `offset`. 344 | // According to the conditions required by the function, we satisfy them because: 345 | // 1. `offsets_l` is stack-allocated, and thus considered separate allocated object. 346 | // 2. The function `is_less` returns a `bool`. 347 | // Casting a `bool` will never overflow `isize`. 348 | // 3. We have guaranteed that `block_l` will be `<= BLOCK`. 349 | // Plus, `end_l` was initially set to the begin pointer of `offsets_` which was declared on the stack. 350 | // Thus, we know that even in the worst case (all invocations of `is_less` returns false) we will only be at most 1 byte pass the end. 351 | // Another unsafety operation here is dereferencing `elem`. 352 | // However, `elem` was initially the begin pointer to the slice which is always valid. 353 | unsafe { 354 | // Branchless comparison. 355 | *end_l = i as u8; 356 | end_l = end_l.offset(!is_less(&*elem, pivot) as isize); 357 | elem = elem.offset(1); 358 | } 359 | } 360 | } 361 | 362 | if start_r == end_r { 363 | // Trace `block_r` elements from the right side. 364 | // TODO unstable: start_r = MaybeUninit::slice_as_mut_ptr(&mut offsets_r); 365 | start_r = offsets_r.as_mut_ptr() as *mut u8; 366 | end_r = start_r; 367 | let mut elem = r; 368 | 369 | for i in 0..block_r { 370 | // SAFETY: The unsafety operations below involve the usage of the `offset`. 371 | // According to the conditions required by the function, we satisfy them because: 372 | // 1. `offsets_r` is stack-allocated, and thus considered separate allocated object. 373 | // 2. The function `is_less` returns a `bool`. 374 | // Casting a `bool` will never overflow `isize`. 375 | // 3. We have guaranteed that `block_r` will be `<= BLOCK`. 376 | // Plus, `end_r` was initially set to the begin pointer of `offsets_` which was declared on the stack. 377 | // Thus, we know that even in the worst case (all invocations of `is_less` returns true) we will only be at most 1 byte pass the end. 378 | // Another unsafety operation here is dereferencing `elem`. 379 | // However, `elem` was initially `1 * sizeof(T)` past the end and we decrement it by `1 * sizeof(T)` before accessing it. 380 | // Plus, `block_r` was asserted to be less than `BLOCK` and `elem` will therefore at most be pointing to the beginning of the slice. 381 | unsafe { 382 | // Branchless comparison. 383 | elem = elem.offset(-1); 384 | *end_r = i as u8; 385 | end_r = end_r.offset(is_less(&*elem, pivot) as isize); 386 | } 387 | } 388 | } 389 | 390 | // Number of out-of-order elements to swap between the left and right side. 391 | let count = cmp::min(width(start_l, end_l), width(start_r, end_r)); 392 | 393 | if count > 0 { 394 | macro_rules! left { 395 | () => { 396 | l.offset(*start_l as isize) 397 | }; 398 | } 399 | macro_rules! right { 400 | () => { 401 | r.offset(-(*start_r as isize) - 1) 402 | }; 403 | } 404 | 405 | // Instead of swapping one pair at the time, it is more efficient to perform a cyclic 406 | // permutation. This is not strictly equivalent to swapping, but produces a similar 407 | // result using fewer memory operations. 408 | 409 | // SAFETY: The use of `ptr::read` is valid because there is at least one element in 410 | // both `offsets_l` and `offsets_r`, so `left!` is a valid pointer to read from. 
411 | // 412 | // The uses of `left!` involve calls to `offset` on `l`, which points to the 413 | // beginning of `v`. All the offsets pointed-to by `start_l` are at most `block_l`, so 414 | // these `offset` calls are safe as all reads are within the block. The same argument 415 | // applies for the uses of `right!`. 416 | // 417 | // The calls to `start_l.offset` are valid because there are at most `count-1` of them, 418 | // plus the final one at the end of the unsafe block, where `count` is the minimum number 419 | // of collected offsets in `offsets_l` and `offsets_r`, so there is no risk of there not 420 | // being enough elements. The same reasoning applies to the calls to `start_r.offset`. 421 | // 422 | // The calls to `copy_nonoverlapping` are safe because `left!` and `right!` are guaranteed 423 | // not to overlap, and are valid because of the reasoning above. 424 | unsafe { 425 | let tmp = ptr::read(left!()); 426 | ptr::copy_nonoverlapping(right!(), left!(), 1); 427 | 428 | for _ in 1..count { 429 | start_l = start_l.offset(1); 430 | ptr::copy_nonoverlapping(left!(), right!(), 1); 431 | start_r = start_r.offset(1); 432 | ptr::copy_nonoverlapping(right!(), left!(), 1); 433 | } 434 | 435 | ptr::copy_nonoverlapping(&tmp, right!(), 1); 436 | mem::forget(tmp); 437 | start_l = start_l.offset(1); 438 | start_r = start_r.offset(1); 439 | } 440 | } 441 | 442 | if start_l == end_l { 443 | // All out-of-order elements in the left block were moved. Move to the next block. 444 | 445 | // block-width-guarantee 446 | // SAFETY: if `!is_done` then the slice width is guaranteed to be at least `2*BLOCK` wide. There 447 | // are at most `BLOCK` elements in `offsets_l` because of its size, so the `offset` operation is 448 | // safe. Otherwise, the debug assertions in the `is_done` case guarantee that 449 | // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account 450 | // for the smaller number of remaining elements. 451 | l = unsafe { l.add(block_l) }; 452 | } 453 | 454 | if start_r == end_r { 455 | // All out-of-order elements in the right block were moved. Move to the previous block. 456 | 457 | // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide, 458 | // or `block_r` has been adjusted for the last handful of elements. 459 | r = unsafe { r.offset(-(block_r as isize)) }; 460 | } 461 | 462 | if is_done { 463 | break; 464 | } 465 | } 466 | 467 | // All that remains now is at most one block (either the left or the right) with out-of-order 468 | // elements that need to be moved. Such remaining elements can be simply shifted to the end 469 | // within their block. 470 | 471 | if start_l < end_l { 472 | // The left block remains. 473 | // Move its remaining out-of-order elements to the far right. 474 | debug_assert_eq!(width(l, r), block_l); 475 | while start_l < end_l { 476 | // remaining-elements-safety 477 | // SAFETY: while the loop condition holds there are still elements in `offsets_l`, so it 478 | // is safe to point `end_l` to the previous element. 479 | // 480 | // The `ptr::swap` is safe if both its arguments are valid for reads and writes: 481 | // - Per the debug assert above, the distance between `l` and `r` is `block_l` 482 | // elements, so there can be at most `block_l` remaining offsets between `start_l` 483 | // and `end_l`. This means `r` will be moved at most `block_l` steps back, which 484 | // makes the `r.offset` calls valid (at that point `l == r`). 
485 | // - `offsets_l` contains valid offsets into `v` collected during the partitioning of 486 | // the last block, so the `l.offset` calls are valid. 487 | unsafe { 488 | end_l = end_l.offset(-1); 489 | ptr::swap(l.offset(*end_l as isize), r.offset(-1)); 490 | r = r.offset(-1); 491 | } 492 | } 493 | width(v.as_mut_ptr(), r) 494 | } else if start_r < end_r { 495 | // The right block remains. 496 | // Move its remaining out-of-order elements to the far left. 497 | debug_assert_eq!(width(l, r), block_r); 498 | while start_r < end_r { 499 | // SAFETY: See the reasoning in [remaining-elements-safety]. 500 | unsafe { 501 | end_r = end_r.offset(-1); 502 | ptr::swap(l, r.offset(-(*end_r as isize) - 1)); 503 | l = l.offset(1); 504 | } 505 | } 506 | width(v.as_mut_ptr(), l) 507 | } else { 508 | // Nothing else to do, we're done. 509 | width(v.as_mut_ptr(), l) 510 | } 511 | } 512 | 513 | /// Partitions `v` into elements smaller than `v[pivot]`, followed by elements greater than or 514 | /// equal to `v[pivot]`. 515 | /// 516 | /// Returns a tuple of: 517 | /// 518 | /// 1. Number of elements smaller than `v[pivot]`. 519 | /// 2. True if `v` was already partitioned. 520 | fn partition(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool) 521 | where 522 | F: Fn(&T, &T) -> bool, 523 | { 524 | let (mid, was_partitioned) = { 525 | // Place the pivot at the beginning of slice. 526 | v.swap(0, pivot); 527 | let (pivot, v) = v.split_at_mut(1); 528 | let pivot = &mut pivot[0]; 529 | 530 | // Read the pivot into a stack-allocated variable for efficiency. If a following comparison 531 | // operation panics, the pivot will be automatically written back into the slice. 532 | 533 | // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe. 534 | let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); 535 | let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) }; 536 | let pivot = &*tmp; 537 | 538 | // Find the first pair of out-of-order elements. 539 | let mut l = 0; 540 | let mut r = v.len(); 541 | 542 | // SAFETY: The unsafety below involves indexing an array. 543 | // For the first one: We already do the bounds checking here with `l < r`. 544 | // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation. 545 | // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one. 546 | unsafe { 547 | // Find the first element greater than or equal to the pivot. 548 | while l < r && is_less(v.get_unchecked(l), pivot) { 549 | l += 1; 550 | } 551 | 552 | // Find the last element smaller that the pivot. 553 | while l < r && !is_less(v.get_unchecked(r - 1), pivot) { 554 | r -= 1; 555 | } 556 | } 557 | 558 | ( 559 | l + partition_in_blocks(&mut v[l..r], pivot, is_less), 560 | l >= r, 561 | ) 562 | 563 | // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated 564 | // variable) back into the slice where it originally was. This step is critical in ensuring 565 | // safety! 566 | }; 567 | 568 | // Place the pivot between the two partitions. 569 | v.swap(0, mid); 570 | 571 | (mid, was_partitioned) 572 | } 573 | 574 | /// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than `v[pivot]`. 575 | /// 576 | /// Returns the number of elements equal to the pivot. It is assumed that `v` does not contain 577 | /// elements smaller than the pivot. 
578 | fn partition_equal(v: &mut [T], pivot: usize, is_less: &F) -> usize 579 | where 580 | F: Fn(&T, &T) -> bool, 581 | { 582 | // Place the pivot at the beginning of slice. 583 | v.swap(0, pivot); 584 | let (pivot, v) = v.split_at_mut(1); 585 | let pivot = &mut pivot[0]; 586 | 587 | // Read the pivot into a stack-allocated variable for efficiency. If a following comparison 588 | // operation panics, the pivot will be automatically written back into the slice. 589 | // SAFETY: The pointer here is valid because it is obtained from a reference to a slice. 590 | let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); 591 | let _pivot_guard = unsafe { CopyOnDrop::new(&*tmp, pivot) }; 592 | let pivot = &*tmp; 593 | 594 | // Now partition the slice. 595 | let mut l = 0; 596 | let mut r = v.len(); 597 | loop { 598 | // SAFETY: The unsafety below involves indexing an array. 599 | // For the first one: We already do the bounds checking here with `l < r`. 600 | // For the second one: We initially have `l == 0` and `r == v.len()` and we checked that `l < r` at every indexing operation. 601 | // From here we know that `r` must be at least `r == l` which was shown to be valid from the first one. 602 | unsafe { 603 | // Find the first element greater than the pivot. 604 | while l < r && !is_less(pivot, v.get_unchecked(l)) { 605 | l += 1; 606 | } 607 | 608 | // Find the last element equal to the pivot. 609 | while l < r && is_less(pivot, v.get_unchecked(r - 1)) { 610 | r -= 1; 611 | } 612 | 613 | // Are we done? 614 | if l >= r { 615 | break; 616 | } 617 | 618 | // Swap the found pair of out-of-order elements. 619 | r -= 1; 620 | let ptr = v.as_mut_ptr(); 621 | ptr::swap(ptr.add(l), ptr.add(r)); 622 | l += 1; 623 | } 624 | } 625 | 626 | // We found `l` elements equal to the pivot. Add 1 to account for the pivot itself. 627 | l + 1 628 | 629 | // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated variable) 630 | // back into the slice where it originally was. This step is critical in ensuring safety! 631 | } 632 | 633 | /// Scatters some elements around in an attempt to break patterns that might cause imbalanced 634 | /// partitions in quicksort. 635 | #[cold] 636 | fn break_patterns(v: &mut [T]) { 637 | let len = v.len(); 638 | if len >= 8 { 639 | // Pseudorandom number generator from the "Xorshift RNGs" paper by George Marsaglia. 640 | let mut random = len as u32; 641 | let mut gen_u32 = || { 642 | random ^= random << 13; 643 | random ^= random >> 17; 644 | random ^= random << 5; 645 | random 646 | }; 647 | let mut gen_usize = || { 648 | if usize::BITS <= 32 { 649 | gen_u32() as usize 650 | } else { 651 | (((gen_u32() as u64) << 32) | (gen_u32() as u64)) as usize 652 | } 653 | }; 654 | 655 | // Take random numbers modulo this number. 656 | // The number fits into `usize` because `len` is not greater than `isize::MAX`. 657 | let modulus = len.next_power_of_two(); 658 | 659 | // Some pivot candidates will be in the nearby of this index. Let's randomize them. 660 | let pos = len / 4 * 2; 661 | 662 | for i in 0..3 { 663 | // Generate a random number modulo `len`. However, in order to avoid costly operations 664 | // we first take it modulo a power of two, and then decrease by `len` until it fits 665 | // into the range `[0, len - 1]`. 666 | let mut other = gen_usize() & (modulus - 1); 667 | 668 | // `other` is guaranteed to be less than `2 * len`. 
669 | if other >= len { 670 | other -= len; 671 | } 672 | 673 | v.swap(pos - 1 + i, other); 674 | } 675 | } 676 | } 677 | 678 | /// Chooses a pivot in `v` and returns the index and `true` if the slice is likely already sorted. 679 | /// 680 | /// Elements in `v` might be reordered in the process. 681 | fn choose_pivot(v: &mut [T], is_less: &F) -> (usize, bool) 682 | where 683 | F: Fn(&T, &T) -> bool, 684 | { 685 | // Minimum length to choose the median-of-medians method. 686 | // Shorter slices use the simple median-of-three method. 687 | const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50; 688 | // Maximum number of swaps that can be performed in this function. 689 | const MAX_SWAPS: usize = 4 * 3; 690 | 691 | let len = v.len(); 692 | 693 | // Three indices near which we are going to choose a pivot. 694 | #[allow(clippy::identity_op)] 695 | let mut a = len / 4 * 1; 696 | let mut b = len / 4 * 2; 697 | let mut c = len / 4 * 3; 698 | 699 | // Counts the total number of swaps we are about to perform while sorting indices. 700 | let mut swaps = 0; 701 | 702 | if len >= 8 { 703 | // Swaps indices so that `v[a] <= v[b]`. 704 | // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of 705 | // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in 706 | // corresponding calls to `sort3` with valid 3-item neighborhoods around each 707 | // pointer, which in turn means the calls to `sort2` are done with valid 708 | // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap` 709 | // call. 710 | let mut sort2 = |a: &mut usize, b: &mut usize| unsafe { 711 | if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) { 712 | ptr::swap(a, b); 713 | swaps += 1; 714 | } 715 | }; 716 | 717 | // Swaps indices so that `v[a] <= v[b] <= v[c]`. 718 | let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| { 719 | sort2(a, b); 720 | sort2(b, c); 721 | sort2(a, b); 722 | }; 723 | 724 | if len >= SHORTEST_MEDIAN_OF_MEDIANS { 725 | // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`. 726 | let mut sort_adjacent = |a: &mut usize| { 727 | let tmp = *a; 728 | sort3(&mut (tmp - 1), a, &mut (tmp + 1)); 729 | }; 730 | 731 | // Find medians in the neighborhoods of `a`, `b`, and `c`. 732 | sort_adjacent(&mut a); 733 | sort_adjacent(&mut b); 734 | sort_adjacent(&mut c); 735 | } 736 | 737 | // Find the median among `a`, `b`, and `c`. 738 | sort3(&mut a, &mut b, &mut c); 739 | } 740 | 741 | if swaps < MAX_SWAPS { 742 | (b, swaps == 0) 743 | } else { 744 | // The maximum number of swaps was performed. Chances are the slice is descending or mostly 745 | // descending, so reversing will probably help sort it faster. 746 | v.reverse(); 747 | (len - 1 - b, true) 748 | } 749 | } 750 | 751 | /// Sorts `v` recursively. 752 | /// 753 | /// If the slice had a predecessor in the original array, it is specified as `pred`. 754 | /// 755 | /// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero, 756 | /// this function will immediately switch to heapsort. 757 | fn recurse<'a, T, F>( 758 | mut v: &'a mut [T], 759 | is_less: &F, 760 | mut pred: Option<&'a mut T>, 761 | mut limit: u32, 762 | pbar: &ProgressBar, 763 | ) where 764 | T: Send, 765 | F: Fn(&T, &T) -> bool + Sync, 766 | { 767 | // Slices of up to this length get sorted using insertion sort. 768 | const MAX_INSERTION: usize = 20; 769 | // If both partitions are up to this length, we continue sequentially. 
This number is as small 770 | // as possible but so that the overhead of Rayon's task scheduling is still negligible. 771 | const MAX_SEQUENTIAL: usize = 2000; 772 | 773 | // True if the last partitioning was reasonably balanced. 774 | let mut was_balanced = true; 775 | // True if the last partitioning didn't shuffle elements (the slice was already partitioned). 776 | let mut was_partitioned = true; 777 | 778 | loop { 779 | let len = v.len(); 780 | 781 | // Very short slices get sorted using insertion sort. 782 | if len <= MAX_INSERTION { 783 | insertion_sort(v, is_less); 784 | return; 785 | } 786 | 787 | // If too many bad pivot choices were made, simply fall back to heapsort in order to 788 | // guarantee `O(n * log(n))` worst-case. 789 | if limit == 0 { 790 | heapsort(v, is_less); 791 | return; 792 | } 793 | 794 | // If the last partitioning was imbalanced, try breaking patterns in the slice by shuffling 795 | // some elements around. Hopefully we'll choose a better pivot this time. 796 | if !was_balanced { 797 | break_patterns(v); 798 | limit -= 1; 799 | } 800 | 801 | // Choose a pivot and try guessing whether the slice is already sorted. 802 | let (pivot, likely_sorted) = choose_pivot(v, is_less); 803 | 804 | // If the last partitioning was decently balanced and didn't shuffle elements, and if pivot 805 | // selection predicts the slice is likely already sorted... 806 | if was_balanced && was_partitioned && likely_sorted { 807 | // Try identifying several out-of-order elements and shifting them to correct 808 | // positions. If the slice ends up being completely sorted, we're done. 809 | if partial_insertion_sort(v, is_less) { 810 | return; 811 | } 812 | } 813 | 814 | // If the chosen pivot is equal to the predecessor, then it's the smallest element in the 815 | // slice. Partition the slice into elements equal to and elements greater than the pivot. 816 | // This case is usually hit when the slice contains many duplicate elements. 817 | if let Some(ref p) = pred { 818 | if !is_less(p, &v[pivot]) { 819 | let mid = partition_equal(v, pivot, is_less); 820 | 821 | // Continue sorting elements greater than the pivot. 822 | v = &mut v[mid..]; 823 | continue; 824 | } 825 | } 826 | 827 | // Partition the slice. 828 | let (mid, was_p) = partition(v, pivot, is_less); 829 | was_balanced = cmp::min(mid, len - mid) >= len / 8; 830 | was_partitioned = was_p; 831 | 832 | // Split the slice into `left`, `pivot`, and `right`. 833 | let (left, right) = v.split_at_mut(mid); 834 | let (pivot, right) = right.split_at_mut(1); 835 | let pivot = &mut pivot[0]; 836 | 837 | if cmp::max(left.len(), right.len()) <= MAX_SEQUENTIAL { 838 | // Recurse into the shorter side only in order to minimize the total number of recursive 839 | // calls and consume less stack space. Then just continue with the longer side (this is 840 | // akin to tail recursion). 841 | if left.len() < right.len() { 842 | recurse(left, is_less, pred, limit, pbar); 843 | v = right; 844 | pred = Some(pivot); 845 | } else { 846 | recurse(right, is_less, Some(pivot), limit, pbar); 847 | v = left; 848 | } 849 | } else { 850 | pbar.inc(1); 851 | // Sort the left and right half in parallel. 852 | rayon_core::join( 853 | || recurse(left, is_less, pred, limit, pbar), 854 | || recurse(right, is_less, Some(pivot), limit, pbar), 855 | ); 856 | break; 857 | } 858 | } 859 | } 860 | 861 | /// Sorts `v` using pattern-defeating quicksort in parallel. 862 | /// 863 | /// The algorithm is unstable, in-place, and *O*(*n* \* log(*n*)) worst-case. 
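// In this crate the entry point below is reached through `par_sort_unstable_by_key` at the top of
// this file; the suffix-array build in memmap_index.rs invokes it roughly as follows (see the
// `build` function earlier in this document for the real call site):
//
//     par_sort_unstable_by_key(
//         table_mmap.as_slice_mut(),          // &mut [u64] of suffix start indices
//         |&i| &text_mmap[i as usize..],      // key: the suffix each index points at
//         verbose,
//     );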
864 | pub fn par_quicksort(v: &mut [T], is_less: F, verbose: bool) 865 | where 866 | T: Send, 867 | F: Fn(&T, &T) -> bool + Sync, 868 | { 869 | // Sorting has no meaningful behavior on zero-sized types. 870 | if mem::size_of::() == 0 { 871 | return; 872 | } 873 | 874 | // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`. 875 | let limit = usize::BITS - v.len().leading_zeros(); 876 | let pbar = if verbose { 877 | let p = ProgressBar::new((v.len() as f64 / 2000.0).ceil() as u64); 878 | p.set_style( 879 | ProgressStyle::with_template( 880 | "{elapsed} elapsed (estimated duration {duration}) {bar:80}", 881 | ) 882 | .unwrap(), 883 | ); 884 | p 885 | } else { 886 | ProgressBar::hidden() 887 | }; 888 | recurse(v, &is_less, None, limit, &pbar); 889 | pbar.finish(); 890 | } 891 | 892 | #[cfg(test)] 893 | mod tests { 894 | use super::heapsort; 895 | use rand::distributions::Uniform; 896 | use rand::{thread_rng, Rng}; 897 | 898 | #[test] 899 | fn test_heapsort() { 900 | let rng = &mut thread_rng(); 901 | 902 | for len in (0..25).chain(500..501) { 903 | for &modulus in &[5, 10, 100] { 904 | let dist = Uniform::new(0, modulus); 905 | for _ in 0..100 { 906 | let v: Vec = rng.sample_iter(&dist).take(len).collect(); 907 | 908 | // Test heapsort using `<` operator. 909 | let mut tmp = v.clone(); 910 | heapsort(&mut tmp, &|a, b| a < b); 911 | assert!(tmp.windows(2).all(|w| w[0] <= w[1])); 912 | 913 | // Test heapsort using `>` operator. 914 | let mut tmp = v.clone(); 915 | heapsort(&mut tmp, &|a, b| a > b); 916 | assert!(tmp.windows(2).all(|w| w[0] >= w[1])); 917 | } 918 | } 919 | } 920 | 921 | // Sort using a completely random comparison function. 922 | // This will reorder the elements *somehow*, but won't panic. 923 | let mut v: Vec<_> = (0..100).collect(); 924 | heapsort(&mut v, &|_, _| thread_rng().gen()); 925 | heapsort(&mut v, &|a, b| a < b); 926 | 927 | for (i, &entry) in v.iter().enumerate() { 928 | assert_eq!(entry, i); 929 | } 930 | } 931 | } 932 | -------------------------------------------------------------------------------- /src/sample.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use funty::Unsigned; 3 | use rand::distributions::{Distribution, WeightedIndex}; 4 | use rand::thread_rng; 5 | use rayon::prelude::*; 6 | use serde::{Deserialize, Serialize}; 7 | use std::collections::HashMap; 8 | use std::ops::Mul; 9 | 10 | #[derive(Clone, Deserialize, Serialize, Default)] 11 | pub struct KneserNeyCache { 12 | unigram_probs: Option>, 13 | n_delta: HashMap, 14 | } 15 | 16 | pub trait Sample: Send + Sync { 17 | fn count_next_slice(&self, query: &[T]) -> Vec; 18 | 19 | /// Generate a frequency map from occurrence frequency to the number of 20 | /// unique n-grams in the corpus with that frequency. 
21 | fn count_ngrams(&self, n: usize) -> HashMap; 22 | 23 | fn get_cache(&self) -> &KneserNeyCache; 24 | 25 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache; 26 | 27 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 28 | fn sample_unsmoothed( 29 | &self, 30 | query: &[T], 31 | n: usize, 32 | k: usize, 33 | num_samples: usize, 34 | ) -> Result>> { 35 | (0..num_samples) 36 | .into_par_iter() 37 | .map(|_| self.sample(query, n, k)) 38 | .collect() 39 | } 40 | 41 | //// Autoregressively sample a sequence of k characters from an unsmoothed n-gram model.""" 42 | fn sample(&self, query: &[T], n: usize, k: usize) -> Result> { 43 | let mut rng = thread_rng(); 44 | let mut sequence = Vec::from(query); 45 | 46 | for _ in 0..k { 47 | // look at the previous (n - 1) characters to predict the n-gram completion 48 | let start = sequence.len().saturating_sub(n - 1); 49 | let prev = &sequence[start..]; 50 | 51 | let counts = self.count_next_slice(prev); 52 | let dist = WeightedIndex::new(&counts)?; 53 | let sampled_index: T = dist 54 | .sample(&mut rng) 55 | .try_into() 56 | .unwrap_or_else(|_| panic!("Sampled token > T::MAX")); 57 | 58 | sequence.push(sampled_index); 59 | } 60 | 61 | Ok(sequence) 62 | } 63 | 64 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 65 | /// tokens in the query. 66 | fn get_smoothed_probs(&mut self, query: &[T]) -> Vec { 67 | self.estimate_deltas(1); 68 | self.compute_smoothed_unigram_probs(); 69 | self.smoothed_probs(query) 70 | } 71 | 72 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 73 | /// tokens in the query. 74 | fn batch_get_smoothed_probs(&mut self, queries: &[Vec]) -> Vec> { 75 | self.estimate_deltas(1); 76 | self.compute_smoothed_unigram_probs(); 77 | 78 | queries 79 | .into_par_iter() 80 | .map(|query| self.smoothed_probs(query)) 81 | .collect() 82 | } 83 | 84 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 
85 | fn sample_smoothed( 86 | &mut self, 87 | query: &[T], 88 | n: usize, 89 | k: usize, 90 | num_samples: usize, 91 | ) -> Result>> { 92 | self.estimate_deltas(1); 93 | self.compute_smoothed_unigram_probs(); 94 | 95 | (0..num_samples) 96 | .into_par_iter() 97 | .map(|_| self.kn_sample(query, n, k)) 98 | .collect() 99 | } 100 | 101 | /// Returns the Kneser-Ney smoothed token probability distribution for a query 102 | /// continuation using absolute discounting as described in 103 | /// "On structuring probabilistic dependences in stochastic language modelling", page 25, 104 | /// doi:10.1006/csla.1994.1001 105 | fn smoothed_probs(&self, query: &[T]) -> Vec { 106 | let p_continuations = if query.is_empty() { 107 | self.get_cached_smoothed_unigram_probs().to_vec() 108 | } else { 109 | self.smoothed_probs(&query[1..]) 110 | }; 111 | 112 | let counts = self.count_next_slice(&query); 113 | let suffix_count_recip = { 114 | let suffix_count: usize = counts.iter().sum(); 115 | if suffix_count == 0 { 116 | return p_continuations; 117 | } 118 | 1.0 / suffix_count as f64 119 | }; 120 | 121 | let (gt_zero_count, eq_one_count) = get_occurrence_counts(&counts); 122 | let used_suffix_count = gt_zero_count as f64; 123 | let used_once_suffix_count = eq_one_count as f64; 124 | 125 | // Interpolation budget to be distributed according to lower order n-gram distribution 126 | let delta = self.get_cached_delta(query.len() + 1); 127 | let lambda = if delta < 1.0 { 128 | delta.mul(used_suffix_count).mul(suffix_count_recip) 129 | } else { 130 | used_once_suffix_count 131 | + delta 132 | .mul(used_suffix_count - used_once_suffix_count) 133 | .mul(suffix_count_recip) 134 | }; 135 | 136 | let mut probs = Vec::with_capacity(counts.len()); 137 | counts 138 | .iter() 139 | .zip(p_continuations.iter()) 140 | .for_each(|(&count, &p_continuation)| { 141 | let prob = (count as f64 - delta).max(0.0).mul(suffix_count_recip) 142 | + lambda.mul(p_continuation); 143 | probs.push(prob); 144 | }); 145 | probs 146 | } 147 | 148 | /// Autoregressively sample k characters from a Kneser-Ney smoothed n-gram model. 149 | fn kn_sample(&self, query: &[T], n: usize, k: usize) -> Result> { 150 | let mut rng = thread_rng(); 151 | let mut sequence = Vec::from(query); 152 | 153 | for _ in 0..k { 154 | let start = sequence.len().saturating_sub(n - 1); 155 | let prev = &sequence[start..]; 156 | let probs = self.smoothed_probs(prev); 157 | let dist = WeightedIndex::new(&probs)?; 158 | let sampled_index: T = dist 159 | .sample(&mut rng) 160 | .try_into() 161 | .unwrap_or_else(|_| panic!("Sampled token > usize::MAX")); 162 | 163 | sequence.push(sampled_index); 164 | } 165 | 166 | Ok(sequence) 167 | } 168 | 169 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 170 | /// Improve smoothed model quality by replacing the default delta hyperparameters 171 | /// for models of order n and below with improved estimates over the entire index. 172 | /// , page 16. 173 | fn estimate_deltas(&mut self, n: usize) { 174 | for i in 1..n + 1 { 175 | if self.get_cache().n_delta.contains_key(&i) { 176 | continue; 177 | } 178 | 179 | let count_map = self.count_ngrams(i); 180 | let n1 = *count_map.get(&1).unwrap_or(&0) as f64; 181 | let n2 = *count_map.get(&2).unwrap_or(&0) as f64; 182 | 183 | // n1 and n2 are greater than 0 for non-trivial datasets 184 | let delta = if n1 == 0. || n2 == 0. { 185 | 1. 
186 | } else { 187 | n1 / (n1 + n2.mul(2.)) 188 | }; 189 | 190 | self.get_mut_cache().n_delta.insert(i, delta); 191 | } 192 | } 193 | 194 | fn get_cached_delta(&self, n: usize) -> f64 { 195 | *self.get_cache().n_delta.get(&n).unwrap_or(&0.5) 196 | } 197 | 198 | /// Returns unigram probabilities with additive smoothing applied. 199 | fn compute_smoothed_unigram_probs(&mut self) { 200 | if let Some(_) = &self.get_cache().unigram_probs { 201 | return; 202 | } 203 | 204 | let eps = 1e-9; 205 | 206 | // Count the number of unique bigrams that end with each token 207 | let counts = self.count_next_slice(&[]); 208 | 209 | let total_count: usize = counts.iter().sum(); 210 | let adjusted_total_count = total_count as f64 + eps.mul(counts.len() as f64); 211 | let unigram_probs: Vec = counts 212 | .iter() 213 | .map(|&count| (count as f64 + eps) / adjusted_total_count) 214 | .collect(); 215 | 216 | self.get_mut_cache().unigram_probs = Some(unigram_probs); 217 | } 218 | 219 | fn get_cached_smoothed_unigram_probs(&self) -> &[f64] { 220 | self.get_cache().unigram_probs.as_ref().unwrap() 221 | } 222 | } 223 | 224 | fn get_occurrence_counts(slice: &[usize]) -> (usize, usize) { 225 | slice 226 | .iter() 227 | .fold((0, 0), |(gt_zero_count, eq_one_count), &c| { 228 | let gt_zero_count = gt_zero_count + (c > 0) as usize; 229 | let eq_one_count = eq_one_count + (c == 1) as usize; 230 | (gt_zero_count, eq_one_count) 231 | }) 232 | } 233 | -------------------------------------------------------------------------------- /src/sharded_in_memory_index.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use funty::Unsigned; 3 | use std::collections::HashMap; 4 | use rayon::prelude::*; 5 | 6 | use crate::in_memory_index::InMemoryIndexRs; 7 | use crate::sample::{KneserNeyCache, Sample}; 8 | use crate::bindings::sharded_in_memory_index::ShardedInMemoryIndexTrait; 9 | use crate::bindings::in_memory_index::InMemoryIndexTrait; 10 | 11 | /// Expose suffix table functionality over text corpora too large to fit in memory. 
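///
/// Construction sketch (file names are placeholders; each tuple is `(tokens_path, table_path)`
/// for one shard, and the `u16` type parameter and the trailing `?` are assumptions):
///
/// ```ignore
/// let shards = vec![
///     ("shard0_tokens.bin".to_string(), "shard0_table.idx".to_string()),
///     ("shard1_tokens.bin".to_string(), "shard1_table.idx".to_string()),
/// ];
/// let index = ShardedInMemoryIndexRs::<u16>::new(shards, u16::MAX as usize + 1)?;
/// assert!(index.is_sorted());
/// ```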
12 | pub struct ShardedInMemoryIndexRs { 13 | shards: Vec>, 14 | cache: KneserNeyCache, 15 | } 16 | 17 | impl Sample for ShardedInMemoryIndexRs { 18 | fn get_cache(&self) -> &KneserNeyCache { 19 | &self.cache 20 | } 21 | 22 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache { 23 | &mut self.cache 24 | } 25 | 26 | fn count_next_slice(&self, query: &[T]) -> Vec { 27 | let counts = self 28 | .shards 29 | .par_iter() 30 | .map(|shard| shard.count_next_slice(query)) 31 | .collect::>(); 32 | (0..counts[0].len()) 33 | .map(|i| counts.iter().map(|count| count[i]).sum()) 34 | .collect() 35 | } 36 | 37 | fn count_ngrams(&self, n: usize) -> HashMap { 38 | self.shards.iter().map(|shard| shard.count_ngrams(n)).fold( 39 | HashMap::new(), 40 | |mut acc, counts| { 41 | for (k, v) in counts { 42 | *acc.entry(k).or_insert(0) += v; 43 | } 44 | acc 45 | }, 46 | ) 47 | } 48 | } 49 | 50 | impl ShardedInMemoryIndexRs { 51 | pub fn new(paths: Vec<(String, String)>, vocab: usize) -> Result { 52 | let shards: Vec> = paths 53 | .into_iter() 54 | .map(|(text_path, table_path)| { 55 | InMemoryIndexRs::from_disk(text_path, table_path, vocab).unwrap() 56 | }) 57 | .collect(); 58 | 59 | Ok(ShardedInMemoryIndexRs { 60 | shards, 61 | cache: KneserNeyCache::default(), 62 | }) 63 | } 64 | } 65 | 66 | impl ShardedInMemoryIndexTrait for ShardedInMemoryIndexRs { 67 | fn is_sorted(&self) -> bool { 68 | self.shards.iter().all(|shard| shard.is_sorted()) 69 | } 70 | 71 | fn contains(&self, query: Vec) -> bool { 72 | self.shards 73 | .iter() 74 | .any(|shard| shard.contains(query.clone())) 75 | } 76 | 77 | fn count(&self, query: Vec) -> usize { 78 | self.shards 79 | .iter() 80 | .map(|shard| shard.count(query.clone())) 81 | .sum() 82 | } 83 | 84 | fn count_next(&self, query: Vec) -> Vec { 85 | let query: Vec = query 86 | .iter() 87 | .filter_map(|&item| T::try_from(item).ok()) 88 | .collect(); 89 | 90 | let counts = self 91 | .shards 92 | .iter() 93 | .map(|shard| shard.count_next_slice(&query)) 94 | .collect::>(); 95 | (0..counts[0].len()) 96 | .map(|i| counts.iter().map(|count| count[i]).sum()) 97 | .collect() 98 | } 99 | 100 | fn batch_count_next(&self, queries: Vec>) -> Vec> { 101 | let batch_counts = self 102 | .shards 103 | .iter() 104 | .map(|shard| shard.batch_count_next(queries.clone())) 105 | .collect::>(); 106 | 107 | (0..queries.len()) 108 | .map(|i| { 109 | (0..batch_counts[0][i].len()) 110 | .map(|j| batch_counts.iter().map(|count| count[i][j]).sum()) 111 | .collect() 112 | }) 113 | .collect() 114 | } 115 | 116 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 117 | fn sample_unsmoothed( 118 | &self, 119 | query: Vec, 120 | n: usize, 121 | k: usize, 122 | num_samples: usize, 123 | ) -> Result>> { 124 | let query: Vec = query 125 | .iter() 126 | .filter_map(|&item| T::try_from(item).ok()) 127 | .collect(); 128 | 129 | let samples_batch = 130 | >::sample_unsmoothed(self, &query, n, k, num_samples)?; 131 | Ok(samples_batch 132 | .into_iter() 133 | .map(|samples| { 134 | samples 135 | .into_iter() 136 | .filter_map(|sample| { 137 | match TryInto::::try_into(sample) { 138 | Ok(value) => Some(value), 139 | Err(_) => None, // Silently skip values that can't be converted 140 | } 141 | }) 142 | .collect::>() 143 | }) 144 | .collect()) 145 | } 146 | 147 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 148 | /// tokens in the query. 
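///
/// Shape sketch (token ids are arbitrary, `vocab` stands in for the vocabulary size the index
/// was built with, and the tolerance is only illustrative):
///
/// ```ignore
/// let probs = index.get_smoothed_probs(vec![101, 2047, 8]);
/// assert_eq!(probs.len(), vocab);
/// // A smoothed next-token distribution should sum to approximately 1.
/// assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-4);
/// ```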
149 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 150 | let query: Vec = query 151 | .iter() 152 | .filter_map(|&item| T::try_from(item).ok()) 153 | .collect(); 154 | >::get_smoothed_probs(self, &query) 155 | } 156 | 157 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 158 | /// tokens in the query. 159 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 160 | let queries: Vec> = queries 161 | .into_iter() 162 | .map(|query| { 163 | query 164 | .iter() 165 | .filter_map(|&item| T::try_from(item).ok()) 166 | .collect() 167 | }) 168 | .collect(); 169 | >::batch_get_smoothed_probs(self, &queries) 170 | } 171 | 172 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 173 | fn sample_smoothed( 174 | &mut self, 175 | query: Vec, 176 | n: usize, 177 | k: usize, 178 | num_samples: usize, 179 | ) -> Result>> { 180 | let query: Vec = query 181 | .iter() 182 | .filter_map(|&item| T::try_from(item).ok()) 183 | .collect(); 184 | 185 | let samples_batch = >::sample_smoothed(self, &query, n, k, num_samples)?; 186 | Ok(samples_batch 187 | .into_iter() 188 | .map(|samples| { 189 | samples 190 | .into_iter() 191 | .filter_map(|sample| { 192 | match TryInto::::try_into(sample) { 193 | Ok(value) => Some(value), 194 | Err(_) => None, // Silently skip values that can't be converted 195 | } 196 | }) 197 | .collect::>() 198 | }) 199 | .collect()) 200 | } 201 | 202 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 203 | /// Improve smoothed model quality by replacing the default delta hyperparameters 204 | /// for models of order n and below with improved estimates over the entire index. 205 | /// , page 16. 206 | fn estimate_deltas(&mut self, n: usize) { 207 | >::estimate_deltas(self, n); 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /src/sharded_memmap_index.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use funty::Unsigned; 3 | use std::collections::HashMap; 4 | 5 | use crate::bindings::memmap_index::MemmapIndexTrait; 6 | use crate::bindings::sharded_memmap_index::ShardedMemmapIndexTrait; 7 | use crate::memmap_index::MemmapIndexRs; 8 | use crate::sample::{KneserNeyCache, Sample}; 9 | 10 | /// Expose suffix table functionality over text corpora too large to fit in memory. 
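///
/// Build-then-query sketch (paths are placeholders; assuming `build` constructs each shard's
/// suffix table at the paired index path, as the Python `MemmapIndex.build` docs describe, and
/// that the caller returns `anyhow::Result`):
///
/// ```ignore
/// let shards = vec![
///     ("tokens_00.bin".to_string(), "tokens_00.idx".to_string()),
///     ("tokens_01.bin".to_string(), "tokens_01.idx".to_string()),
/// ];
/// let index = ShardedMemmapIndexRs::<u16>::build(shards, u16::MAX as usize + 1, true)?;
/// let unigram_counts = index.count_next(vec![]); // occurrence count for every token id
/// ```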
11 | pub struct ShardedMemmapIndexRs { 12 | shards: Vec>, 13 | cache: KneserNeyCache, 14 | } 15 | 16 | impl Sample for ShardedMemmapIndexRs { 17 | fn get_cache(&self) -> &KneserNeyCache { 18 | &self.cache 19 | } 20 | 21 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache { 22 | &mut self.cache 23 | } 24 | 25 | fn count_next_slice(&self, query: &[T]) -> Vec { 26 | let counts = self 27 | .shards 28 | .iter() 29 | .map(|shard| shard.count_next_slice(query)) 30 | .collect::>(); 31 | (0..counts[0].len()) 32 | .map(|i| counts.iter().map(|count| count[i]).sum()) 33 | .collect() 34 | } 35 | 36 | fn count_ngrams(&self, n: usize) -> HashMap { 37 | self.shards.iter().map(|shard| shard.count_ngrams(n)).fold( 38 | HashMap::new(), 39 | |mut acc, counts| { 40 | for (k, v) in counts { 41 | *acc.entry(k).or_insert(0) += v; 42 | } 43 | acc 44 | }, 45 | ) 46 | } 47 | } 48 | 49 | impl ShardedMemmapIndexRs { 50 | pub fn new(paths: Vec<(String, String)>, vocab: usize) -> Result { 51 | let shards: Vec> = paths 52 | .into_iter() 53 | .map(|(text_path, table_path)| { 54 | MemmapIndexRs::new(text_path, table_path, vocab).unwrap() 55 | }) 56 | .collect(); 57 | 58 | Ok(ShardedMemmapIndexRs { 59 | shards, 60 | cache: KneserNeyCache::default(), 61 | }) 62 | } 63 | 64 | pub fn build(paths: Vec<(String, String)>, vocab: usize, verbose: bool) -> Result { 65 | let shards: Vec> = paths 66 | .into_iter() 67 | .map(|(token_paths, index_paths)| { 68 | MemmapIndexRs::build(token_paths, index_paths, vocab, verbose).unwrap() 69 | }) 70 | .collect(); 71 | 72 | Ok(ShardedMemmapIndexRs { 73 | shards, 74 | cache: KneserNeyCache::default(), 75 | }) 76 | } 77 | } 78 | 79 | impl ShardedMemmapIndexTrait for ShardedMemmapIndexRs { 80 | fn is_sorted(&self) -> bool { 81 | self.shards.iter().all(|shard| shard.is_sorted()) 82 | } 83 | 84 | fn contains(&self, query: Vec) -> bool { 85 | self.shards 86 | .iter() 87 | .any(|shard| shard.contains(query.clone())) 88 | } 89 | 90 | fn count(&self, query: Vec) -> usize { 91 | self.shards 92 | .iter() 93 | .map(|shard| shard.count(query.clone())) 94 | .sum() 95 | } 96 | 97 | fn count_next(&self, query: Vec) -> Vec { 98 | let query: Vec = query 99 | .iter() 100 | .filter_map(|&item| T::try_from(item).ok()) 101 | .collect(); 102 | 103 | let counts = self 104 | .shards 105 | .iter() 106 | .map(|shard| shard.count_next_slice(&query)) 107 | .collect::>(); 108 | (0..counts[0].len()) 109 | .map(|i| counts.iter().map(|count| count[i]).sum()) 110 | .collect() 111 | } 112 | 113 | fn batch_count_next(&self, queries: Vec>) -> Vec> { 114 | let batch_counts = self 115 | .shards 116 | .iter() 117 | .map(|shard| shard.batch_count_next(queries.clone())) 118 | .collect::>(); 119 | 120 | (0..queries.len()) 121 | .map(|i| { 122 | (0..batch_counts[0][i].len()) 123 | .map(|j| batch_counts.iter().map(|count| count[i][j]).sum()) 124 | .collect() 125 | }) 126 | .collect() 127 | } 128 | 129 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" 130 | fn sample_unsmoothed( 131 | &self, 132 | query: Vec, 133 | n: usize, 134 | k: usize, 135 | num_samples: usize, 136 | ) -> Result>> { 137 | let query: Vec = query 138 | .iter() 139 | .filter_map(|&item| T::try_from(item).ok()) 140 | .collect(); 141 | 142 | let samples_batch = 143 | >::sample_unsmoothed(self, &query, n, k, num_samples)?; 144 | Ok(samples_batch 145 | .into_iter() 146 | .map(|samples| { 147 | samples 148 | .into_iter() 149 | .filter_map(|sample| { 150 | match TryInto::::try_into(sample) { 151 | Ok(value) => Some(value), 152 | 
Err(_) => None, // Silently skip values that can't be converted 153 | } 154 | }) 155 | .collect::>() 156 | }) 157 | .collect()) 158 | } 159 | 160 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 161 | /// tokens in the query. 162 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec { 163 | let query: Vec = query 164 | .iter() 165 | .filter_map(|&item| T::try_from(item).ok()) 166 | .collect(); 167 | >::get_smoothed_probs(self, &query) 168 | } 169 | 170 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous 171 | /// tokens in the query. 172 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> { 173 | let queries: Vec> = queries 174 | .into_iter() 175 | .map(|query| { 176 | query 177 | .iter() 178 | .filter_map(|&item| T::try_from(item).ok()) 179 | .collect() 180 | }) 181 | .collect(); 182 | >::batch_get_smoothed_probs(self, &queries) 183 | } 184 | 185 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model. 186 | fn sample_smoothed( 187 | &mut self, 188 | query: Vec, 189 | n: usize, 190 | k: usize, 191 | num_samples: usize, 192 | ) -> Result>> { 193 | let query: Vec = query 194 | .iter() 195 | .filter_map(|&item| T::try_from(item).ok()) 196 | .collect(); 197 | 198 | let samples_batch = >::sample_smoothed(self, &query, n, k, num_samples)?; 199 | Ok(samples_batch 200 | .into_iter() 201 | .map(|samples| { 202 | samples 203 | .into_iter() 204 | .filter_map(|sample| { 205 | match TryInto::::try_into(sample) { 206 | Ok(value) => Some(value), 207 | Err(_) => None, // Silently skip values that can't be converted 208 | } 209 | }) 210 | .collect::>() 211 | }) 212 | .collect()) 213 | } 214 | 215 | /// Warning: O(k**n) where k is vocabulary size, use with caution. 216 | /// Improve smoothed model quality by replacing the default delta hyperparameters 217 | /// for models of order n and below with improved estimates over the entire index. 218 | /// , page 16. 219 | fn estimate_deltas(&mut self, n: usize) { 220 | >::estimate_deltas(self, n); 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /src/table.rs: -------------------------------------------------------------------------------- 1 | extern crate utf16_literal; 2 | 3 | use crate::par_quicksort::par_sort_unstable_by_key; 4 | use funty::Unsigned; 5 | use rayon::prelude::*; 6 | use serde::{Deserialize, Serialize}; 7 | use std::collections::HashMap; 8 | use std::{fmt, ops::Deref, u64}; 9 | 10 | /// A suffix table is a sequence of lexicographically sorted suffixes. 11 | /// The table supports n-gram statistics computation and language modeling over text corpora. 12 | #[derive(Clone, Serialize, Deserialize)] 13 | pub struct SuffixTable, U = Box<[u64]>> { 14 | text: T, 15 | table: U, 16 | vocab: usize, 17 | } 18 | 19 | /// Method for vanilla in-memory suffix tables 20 | impl SuffixTable, Box<[u64]>> { 21 | /// Creates a new suffix table for `text` in `O(n log n)` time and `O(n)` 22 | /// space. 23 | pub fn new(src: S, vocab: Option, verbose: bool) -> Self 24 | where 25 | S: Into>, 26 | { 27 | let text = src.into(); 28 | 29 | // Implicitly store the suffixes using indices into the corpus, 30 | // and sort the suffixes in parallel. Unstable sorting ensures we 31 | // use no extra memory during this operation. 
32 | // 33 | // Rayon's implementation falls back to a sequential algorithm for 34 | // sufficiently small inputs, so we don't need to worry about 35 | // parallelism overhead here. 36 | let mut table: Vec<_> = (0..text.len() as u64).collect(); 37 | par_sort_unstable_by_key(&mut table[..], |&i| &text[i as usize..], verbose); 38 | 39 | let vocab = vocab.unwrap_or(u16::MAX as usize + 1); 40 | 41 | SuffixTable { 42 | text, 43 | table: table.into(), 44 | vocab, 45 | } 46 | } 47 | } 48 | 49 | impl SuffixTable 50 | where 51 | E: Unsigned, 52 | T: Deref + Sync, 53 | U: Deref + Sync, 54 | { 55 | pub fn from_parts(text: T, table: U, vocab: Option) -> Self { 56 | let vocab = vocab.unwrap_or(u16::MAX as usize + 1); 57 | SuffixTable { text, table, vocab } 58 | } 59 | 60 | /// Consumes the suffix table and returns the underlying text and table. 61 | pub fn into_parts(self) -> (T, U) { 62 | (self.text, self.table) 63 | } 64 | 65 | /// Returns the number of suffixes in the table. 66 | /// 67 | /// Alternatively, this is the number of *bytes* in the text. 68 | #[inline] 69 | #[allow(dead_code)] 70 | pub fn len(&self) -> usize { 71 | self.table.len() 72 | } 73 | 74 | /// Returns `true` iff `self.len() == 0`. 75 | #[inline] 76 | #[allow(dead_code)] 77 | pub fn is_empty(&self) -> bool { 78 | self.len() == 0 79 | } 80 | 81 | /// Checks if the suffix table is lexicographically sorted. This is always true for valid suffix tables. 82 | pub fn is_sorted(&self) -> bool { 83 | self.table 84 | .par_windows(2) 85 | .all(|pair| self.text[pair[0] as usize..] <= self.text[pair[1] as usize..]) 86 | } 87 | 88 | /// Returns the suffix at index `i`. 89 | #[inline] 90 | #[allow(dead_code)] 91 | pub fn suffix(&self, i: usize) -> &[E] { 92 | &self.text[self.table[i] as usize..] 93 | } 94 | 95 | /// Returns true if and only if `query` is in text. 96 | /// 97 | /// This runs in `O(mlogn)` time, where `m == query.len()` and 98 | /// `n == self.len()`. (As far as this author knows, this is the best known 99 | /// bound for a plain suffix table.) 100 | /// 101 | /// You should prefer this over `positions` when you only need to test 102 | /// existence (because it is faster). 103 | /// 104 | /// # Example 105 | /// 106 | /// Build a suffix array of some text and test existence of a substring: 107 | /// 108 | /// ```rust 109 | /// use tokengrams::SuffixTable; 110 | /// use utf16_literal::utf16; 111 | /// 112 | /// let sa = SuffixTable::new(utf16!("The quick brown fox.").to_vec(), None, false); 113 | /// assert!(sa.contains(utf16!("quick"))); 114 | /// ``` 115 | #[allow(dead_code)] 116 | pub fn contains(&self, query: &[E]) -> bool { 117 | !query.is_empty() 118 | && self 119 | .table 120 | .binary_search_by(|&sufi| { 121 | self.text[sufi as usize..] 122 | .iter() 123 | .take(query.len()) 124 | .cmp(query.iter()) 125 | }) 126 | .is_ok() 127 | } 128 | 129 | /// Returns an unordered list of positions where `query` starts in `text`. 130 | /// 131 | /// This runs in `O(mlogn)` time, where `m == query.len()` and 132 | /// `n == self.len()`. (As far as this author knows, this is the best known 133 | /// bound for a plain suffix table.) 134 | /// 135 | /// Positions are byte indices into `text`. 136 | /// 137 | /// If you just need to test existence, then use `contains` since it is 138 | /// faster. 
139 | /// 140 | /// # Example 141 | /// 142 | /// Build a suffix array of some text and find all occurrences of a 143 | /// substring: 144 | /// 145 | /// ```rust 146 | /// use tokengrams::SuffixTable; 147 | /// use utf16_literal::utf16; 148 | /// 149 | /// let sa = SuffixTable::new(utf16!("The quick brown fox was very quick.").to_vec(), None, false); 150 | /// assert_eq!(sa.positions(utf16!("quick")), &[4, 29]); 151 | /// ``` 152 | #[allow(dead_code)] 153 | pub fn positions(&self, query: &[E]) -> &[u64] { 154 | // We can quickly decide whether the query won't match at all if 155 | // it's outside the range of suffixes. 156 | if self.text.is_empty() 157 | || query.is_empty() 158 | || (query < self.suffix(0) && !self.suffix(0).starts_with(query)) 159 | || query > self.suffix(self.len() - 1) 160 | { 161 | return &[]; 162 | } 163 | 164 | // The below is pretty close to the algorithm on Wikipedia: 165 | // 166 | // http://en.wikipedia.org/wiki/Suffix_array#Applications 167 | // 168 | // The key difference is that after we find the start index, we look 169 | // for the end by finding the first occurrence that doesn't start 170 | // with `query`. That becomes our upper bound. 171 | let start = binary_search(&self.table, |&sufi| query <= &self.text[sufi as usize..]); 172 | let end = start 173 | + binary_search(&self.table[start..], |&sufi| { 174 | !self.text[sufi as usize..].starts_with(query) 175 | }); 176 | 177 | // Whoops. If start is somehow greater than end, then we've got 178 | // nothing. 179 | if start > end { 180 | &[] 181 | } else { 182 | &self.table[start..end] 183 | } 184 | } 185 | 186 | /// Determine start and end `table` indices of items that start with `query`. 187 | fn boundaries(&self, query: &[E]) -> (usize, usize) { 188 | if self.text.is_empty() || query.is_empty() { 189 | return (0, self.table.len()); 190 | } 191 | if (query < self.suffix(0) && !self.suffix(0).starts_with(query)) 192 | || query > self.suffix(self.len() - 1) 193 | { 194 | return (0, 0); 195 | } 196 | 197 | let start = binary_search(&self.table, |&sufi| query <= &self.text[sufi as usize..]); 198 | let end = start 199 | + binary_search(&self.table[start..], |&sufi| { 200 | !self.text[sufi as usize..].starts_with(query) 201 | }); 202 | 203 | (start, end) 204 | } 205 | 206 | /// Determine start and end indices of items that start with `query` in the `table` range. 207 | fn range_boundaries( 208 | &self, 209 | query: &[E], 210 | range_start: usize, 211 | range_end: usize, 212 | ) -> (usize, usize) { 213 | if self.text.is_empty() 214 | || query.is_empty() 215 | || range_start.eq(&range_end) 216 | || (query < self.suffix(range_start) && !self.suffix(range_start).starts_with(query)) 217 | || query > self.suffix(std::cmp::max(0, range_end - 1)) 218 | { 219 | return (0, 0); 220 | } 221 | 222 | let start = binary_search(&self.table[range_start..range_end], |&sufi| { 223 | query <= &self.text[sufi as usize..] 224 | }); 225 | let end = start 226 | + binary_search(&self.table[range_start + start..range_end], |&sufi| { 227 | !self.text[sufi as usize..].starts_with(query) 228 | }); 229 | 230 | if start > end { 231 | (0, 0) 232 | } else { 233 | (range_start + start, range_start + end) 234 | } 235 | } 236 | 237 | // Count occurrences of each token directly following the query sequence. 
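/// Small illustrative doctest for the method below (token values chosen by hand; with the
/// default `None` vocab the returned vector has 2^16 slots, only a few of which are nonzero):
///
/// ```rust
/// use tokengrams::SuffixTable;
///
/// let table = SuffixTable::new(vec![1u16, 2, 1, 3, 1, 2], None, false);
/// let counts = table.count_next(&[1]);
/// assert_eq!(counts[2], 2); // "1" is followed by "2" twice
/// assert_eq!(counts[3], 1); // and by "3" once
/// ```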
238 | pub fn count_next(&self, query: &[E]) -> Vec { 239 | let mut counts: Vec = vec![0; self.vocab]; 240 | 241 | let (range_start, range_end) = self.boundaries(query); 242 | self.recurse_count_next(&mut counts, query, range_start, range_end); 243 | counts 244 | } 245 | 246 | // count_next helper method. 247 | fn recurse_count_next( 248 | &self, 249 | counts: &mut Vec, 250 | query: &[E], 251 | search_start: usize, 252 | search_end: usize, 253 | ) { 254 | if search_start >= search_end { 255 | return; 256 | } 257 | 258 | let mid = (search_start + search_end) / 2; 259 | let mut suffix = self.suffix(mid); 260 | // The search range may include the query itself, so we need to skip over it. 261 | if suffix == query { 262 | if mid + 1 == search_end { 263 | return; 264 | } 265 | suffix = self.suffix(mid + 1); 266 | } 267 | 268 | let (token_start, token_end) = 269 | self.range_boundaries(&suffix[..query.len() + 1], search_start, search_end); 270 | 271 | counts[suffix[query.len()].as_usize()] = token_end - token_start; 272 | 273 | if search_start < token_start { 274 | self.recurse_count_next(counts, query, search_start, token_start); 275 | } 276 | if token_end < search_end { 277 | self.recurse_count_next(counts, query, token_end, search_end); 278 | } 279 | } 280 | 281 | // count_ngrams helper method. 282 | fn recurse_count_ngrams( 283 | &self, 284 | search_start: usize, 285 | search_end: usize, 286 | n: usize, 287 | query: &[E], 288 | target_n: usize, 289 | count_map: &mut HashMap, 290 | ) { 291 | if search_start == search_end { 292 | return; 293 | } 294 | 295 | let mid = (search_start + search_end) / 2; 296 | let mut suffix = self.suffix(mid); 297 | // The search range may include the query itself, so we need to skip over it. 298 | if suffix == query { 299 | if mid + 1 == search_end { 300 | return; 301 | } 302 | suffix = self.suffix(mid + 1); 303 | } 304 | 305 | let (start, end) = 306 | self.range_boundaries(&suffix[..query.len() + 1], search_start, search_end); 307 | if n < target_n { 308 | self.recurse_count_ngrams( 309 | start, 310 | end, 311 | n + 1, 312 | &suffix[..query.len() + 1], 313 | target_n, 314 | count_map, 315 | ); 316 | } else { 317 | *count_map.entry(end - start).or_insert(0) += 1; 318 | } 319 | 320 | if search_start < start { 321 | self.recurse_count_ngrams(search_start, start, n, query, target_n, count_map); 322 | } 323 | if end < search_end { 324 | self.recurse_count_ngrams(end, search_end, n, query, target_n, count_map); 325 | } 326 | } 327 | 328 | // For a given n, produce a map from an occurrence count to the number of unique n-grams with that occurrence count. 
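/// Small illustrative doctest for the method below, mirroring the existing `contains` and
/// `positions` examples (corpus chosen by hand): in "aaab" the bigram "aa" occurs twice and
/// "ab" once, so the map sends occurrence count 2 to 1 unique bigram and count 1 to 1 unique
/// bigram.
///
/// ```rust
/// use tokengrams::SuffixTable;
/// use utf16_literal::utf16;
///
/// let table = SuffixTable::new(utf16!("aaab").to_vec(), None, false);
/// let freq_of_freqs = table.count_ngrams(2);
/// assert_eq!(freq_of_freqs.get(&2), Some(&1)); // one bigram ("aa") occurs twice
/// assert_eq!(freq_of_freqs.get(&1), Some(&1)); // one bigram ("ab") occurs once
/// ```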
329 | pub fn count_ngrams(&self, n: usize) -> HashMap { 330 | let mut count_map = HashMap::new(); 331 | let (range_start, range_end) = self.boundaries(&[]); 332 | self.recurse_count_ngrams(range_start, range_end, 1, &[], n, &mut count_map); 333 | count_map 334 | } 335 | 336 | pub fn batch_count_next(&self, queries: &[Vec]) -> Vec> { 337 | queries 338 | .into_par_iter() 339 | .map(|query| self.count_next(query)) 340 | .collect() 341 | } 342 | 343 | pub fn get_table(&self) -> &[u64] { 344 | &self.table 345 | } 346 | 347 | pub fn get_text(&self) -> &[E] { 348 | &self.text 349 | } 350 | } 351 | 352 | impl fmt::Debug for SuffixTable { 353 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 354 | writeln!(f, "\n-----------------------------------------")?; 355 | writeln!(f, "SUFFIX TABLE")?; 356 | for (rank, &sufstart) in self.table.iter().enumerate() { 357 | writeln!(f, "suffix[{}] {}", rank, sufstart,)?; 358 | } 359 | writeln!(f, "-----------------------------------------") 360 | } 361 | } 362 | 363 | /// Binary search to find first element such that `pred(T) == true`. 364 | /// 365 | /// Assumes that if `pred(xs[i]) == true` then `pred(xs[i+1]) == true`. 366 | /// 367 | /// If all elements yield `pred(T) == false`, then `xs.len()` is returned. 368 | #[allow(dead_code)] 369 | fn binary_search(xs: &[T], mut pred: F) -> usize 370 | where 371 | F: FnMut(&T) -> bool, 372 | { 373 | let (mut left, mut right) = (0, xs.len()); 374 | while left < right { 375 | let mid = (left + right) / 2; 376 | if pred(&xs[mid]) { 377 | right = mid; 378 | } else { 379 | left = mid + 1; 380 | } 381 | } 382 | left 383 | } 384 | 385 | #[cfg(test)] 386 | mod tests { 387 | use super::*; 388 | use utf16_literal::utf16; 389 | 390 | fn sais(text: &str) -> SuffixTable { 391 | SuffixTable::new(text.encode_utf16().collect::>(), None, false) 392 | } 393 | 394 | #[test] 395 | fn count_next_exists() { 396 | let sa = sais("aaab"); 397 | 398 | let query = utf16!("a"); 399 | let a_index = utf16!("a")[0] as usize; 400 | let b_index = utf16!("b")[0] as usize; 401 | 402 | assert_eq!(2, sa.count_next(query)[a_index]); 403 | assert_eq!(1, sa.count_next(query)[b_index]); 404 | } 405 | 406 | #[test] 407 | fn count_next_empty_query() { 408 | let sa = sais("aaab"); 409 | 410 | let query = utf16!(""); 411 | let a_index = utf16!("a")[0] as usize; 412 | let b_index = utf16!("b")[0] as usize; 413 | 414 | assert_eq!(3, sa.count_next(query)[a_index]); 415 | assert_eq!(1, sa.count_next(query)[b_index]); 416 | } 417 | 418 | #[test] 419 | fn batch_count_next_exists() { 420 | let sa = sais("aaab"); 421 | 422 | let queries: Vec> = vec![vec![utf16!("a")[0]; 1]; 10_000]; 423 | 424 | let a_index = utf16!("a")[0] as usize; 425 | let b_index = utf16!("b")[0] as usize; 426 | 427 | assert_eq!(2, sa.batch_count_next(&queries)[0][a_index]); 428 | assert_eq!(1, sa.batch_count_next(&queries)[0][b_index]); 429 | } 430 | } 431 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | /// Return a zero-copy view of the given slice with the given type. 2 | /// The resulting view has the same lifetime as the provided slice. 3 | #[inline] 4 | pub fn transmute_slice<'a, T, U>(slice: &'a [T]) -> &'a [U] { 5 | // SAFETY: We use floor division to ensure that we can't read past the end of the slice. 
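// For example (illustrative numbers): viewing 100 `u8` bytes as `u64` gives
// (100 * 1) / 8 = 12 elements, so the 4 trailing bytes are simply not exposed by the new view.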
6 | let new_len = (slice.len() * std::mem::size_of::()) / std::mem::size_of::(); 7 | unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const U, new_len) } 8 | } 9 | 10 | #[cfg(test)] 11 | mod tests { 12 | use super::*; 13 | use rand::RngCore; 14 | 15 | macro_rules! test_transmute { 16 | ($bytes:ident, $type:ty) => { 17 | let num_bytes = std::mem::size_of::<$type>(); 18 | let transmuted = transmute_slice::(&$bytes); 19 | assert_eq!(transmuted.len(), $bytes.len() / num_bytes); 20 | 21 | for (cis, trans) in $bytes.chunks(num_bytes).zip(transmuted) { 22 | assert_eq!(<$type>::from_le_bytes(cis.try_into().unwrap()), *trans); 23 | } 24 | }; 25 | } 26 | 27 | #[test] 28 | fn test_transmute_slice() { 29 | let mut rng = rand::thread_rng(); 30 | let mut bytes = vec![0u8; 100]; 31 | rng.fill_bytes(&mut bytes); 32 | 33 | test_transmute!(bytes, u8); 34 | test_transmute!(bytes, u16); 35 | test_transmute!(bytes, u32); 36 | test_transmute!(bytes, u64); 37 | test_transmute!(bytes, u128); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /tests/tests.rs: -------------------------------------------------------------------------------- 1 | extern crate quickcheck; 2 | extern crate utf16_literal; 3 | 4 | use quickcheck::{QuickCheck, Testable}; 5 | use tokengrams::SuffixTable; 6 | use utf16_literal::utf16; 7 | 8 | fn sais(text: &str) -> SuffixTable { 9 | SuffixTable::new(text.encode_utf16().collect::>(), None, false) 10 | } 11 | 12 | fn qc(f: T) { 13 | QuickCheck::new().tests(1000).max_tests(10000).quickcheck(f); 14 | } 15 | 16 | // Do some testing on substring search. 17 | 18 | #[test] 19 | fn empty_find_empty() { 20 | let sa = sais(""); 21 | assert_eq!(sa.positions(&[]), &[]); 22 | assert!(!sa.contains(&[])); 23 | } 24 | 25 | #[test] 26 | fn empty_find_one() { 27 | let sa = sais(""); 28 | assert_eq!(sa.positions(utf16!("a")), &[]); 29 | assert!(!sa.contains(utf16!("a"))); 30 | } 31 | 32 | #[test] 33 | fn empty_find_two() { 34 | let sa = sais(""); 35 | assert_eq!(sa.positions(utf16!("ab")), &[]); 36 | assert!(!sa.contains(utf16!("ab"))); 37 | } 38 | 39 | #[test] 40 | fn one_find_empty() { 41 | let sa = sais("a"); 42 | assert_eq!(sa.positions(utf16!("")), &[]); 43 | assert!(!sa.contains(utf16!(""))); 44 | } 45 | 46 | #[test] 47 | fn one_find_one_notexists() { 48 | let sa = sais("a"); 49 | assert_eq!(sa.positions(utf16!("b")), &[]); 50 | assert!(!sa.contains(utf16!("b"))); 51 | } 52 | 53 | #[test] 54 | fn one_find_one_exists() { 55 | let sa = sais("a"); 56 | assert_eq!(sa.positions(utf16!("a")), &[0]); 57 | assert!(sa.contains(utf16!("a"))); 58 | } 59 | 60 | #[test] 61 | fn two_find_one_exists() { 62 | let sa = sais("ab"); 63 | assert_eq!(sa.positions(utf16!("b")), &[1]); 64 | assert!(sa.contains(utf16!("b"))); 65 | } 66 | 67 | #[test] 68 | fn two_find_two_exists() { 69 | let sa = sais("aa"); 70 | assert_eq!(vec![1, 0], sa.positions(utf16!("a"))); 71 | assert!(sa.contains(utf16!("a"))); 72 | } 73 | 74 | #[test] 75 | fn many_exists() { 76 | let sa = sais("zzzzzaazzzzz"); 77 | assert_eq!(vec![5, 6], sa.positions(utf16!("a"))); 78 | assert!(sa.contains(utf16!("a"))); 79 | } 80 | 81 | #[test] 82 | fn many_exists_long() { 83 | let sa = sais("zzzzabczzzzzabczzzzzz"); 84 | assert_eq!(sa.positions(utf16!("abc")), &[4, 12]); 85 | assert!(sa.contains(utf16!("abc"))); 86 | } 87 | 88 | #[test] 89 | fn query_longer() { 90 | let sa = sais("az"); 91 | assert_eq!(sa.positions(utf16!("mnomnomnomnomnomnomno")), &[]); 92 | assert!(!sa.contains(utf16!("mnomnomnomnomnomnomno"))); 
93 | } 94 | 95 | #[test] 96 | fn query_longer_less() { 97 | let sa = sais("zz"); 98 | assert_eq!(sa.positions(utf16!("mnomnomnomnomnomnomno")), &[]); 99 | assert!(!sa.contains(utf16!("mnomnomnomnomnomnomno"))); 100 | } 101 | 102 | #[test] 103 | fn query_longer_greater() { 104 | let sa = sais("aa"); 105 | assert_eq!(sa.positions(utf16!("mnomnomnomnomnomnomno")), &[]); 106 | assert!(!sa.contains(utf16!("mnomnomnomnomnomnomno"))); 107 | } 108 | 109 | #[test] 110 | fn query_spaces() { 111 | let sa = sais("The quick brown fox was very quick."); 112 | assert_eq!(sa.positions(utf16!("quick")), &[4, 29]); 113 | } 114 | 115 | #[test] 116 | fn prop_length() { 117 | fn prop(s: String) -> bool { 118 | s.encode_utf16().count() == sais(&s).len() 119 | } 120 | qc(prop as fn(String) -> bool); 121 | } 122 | 123 | #[test] 124 | fn prop_contains() { 125 | fn prop(s: String, c: u8) -> bool { 126 | let c = (c as char).to_string(); 127 | let c16 = c.encode_utf16().collect::>(); 128 | s.contains(&c) == sais(&s).contains(c16.as_slice()) 129 | } 130 | qc(prop as fn(String, u8) -> bool); 131 | } 132 | 133 | #[test] 134 | fn prop_positions() { 135 | fn prop(s: String, c: u16) -> bool { 136 | let s = s.encode_utf16().collect::>(); 137 | let table = SuffixTable::new(s.clone(), None, false); 138 | 139 | let got = table.positions(&[c]); 140 | for (i, c_) in s.into_iter().enumerate() { 141 | if (c_ == c) != got.contains(&(i as u64)) { 142 | return false; 143 | } 144 | } 145 | true 146 | } 147 | qc(prop as fn(String, u16) -> bool); 148 | } 149 | -------------------------------------------------------------------------------- /tokengrams/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokengrams import ( 2 | InMemoryIndex, 3 | MemmapIndex, 4 | ShardedMemmapIndex, 5 | ShardedInMemoryIndex 6 | ) 7 | 8 | from .utils.tokenize_hf_dataset import tokenize_hf_dataset -------------------------------------------------------------------------------- /tokengrams/benchmark/InMemoryIndex_build_times.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/benchmark/InMemoryIndex_build_times.png -------------------------------------------------------------------------------- /tokengrams/benchmark/InMemoryIndex_count_next_times.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/benchmark/InMemoryIndex_count_next_times.png -------------------------------------------------------------------------------- /tokengrams/benchmark/MemmapIndex_build_times.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/benchmark/MemmapIndex_build_times.png -------------------------------------------------------------------------------- /tokengrams/benchmark/MemmapIndex_count_next_times.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/benchmark/MemmapIndex_count_next_times.png -------------------------------------------------------------------------------- /tokengrams/benchmark/benchmark.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | from argparse import ArgumentParser 5 | from typing import Literal 6 | from pathlib import Path 7 | 8 | from tokengrams import MemmapIndex, InMemoryIndex 9 | import numpy as np 10 | import plotly.graph_objects as go 11 | 12 | def benchmark(document_path: str, cls: Literal["InMemoryIndex", "MemmapIndex"], encoding_width=16, vocab=2**16): 13 | slice_sizes = [1, 10, 100, 1000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000] 14 | 15 | file_size = os.path.getsize(document_path) 16 | assert encoding_width % 8 == 0, "Encoding width must be a multiple of 8" 17 | total_tokens = file_size // (encoding_width / 8) # Divide by document word length in bytes 18 | 19 | build_times = [] 20 | count_next_times = [] 21 | 22 | tokens = np.memmap(document_path, dtype=np.uint16, mode='r') 23 | for size in slice_sizes: 24 | if size > total_tokens: 25 | print(f"Skipping slice size {size} as it exceeds the total number of tokens.") 26 | continue 27 | 28 | slice_data = tokens[:size] 29 | output_file = f"tmp_slice_{size}.bin" 30 | with open(output_file, 'wb') as f: 31 | slice_data.tofile(f) 32 | 33 | print(f"Created file from slice of {size} tokens: {output_file}") 34 | 35 | # Build index 36 | tmp_index_file = f"tmp_slice_{size}.idx" 37 | start = time.time() 38 | if cls == "MemmapIndex": 39 | index = MemmapIndex.build(output_file, tmp_index_file, verbose=True) 40 | else: 41 | index = InMemoryIndex.from_token_file(output_file, verbose=True) 42 | build_times.append(time.time() - start) 43 | print(f"Built index for slice of {size} tokens in {time.time() - start:.2f} seconds. Updated data:") 44 | print(build_times) 45 | 46 | # Count next with empty query (count unigrams) 47 | start = time.time() 48 | index.count_next([]) 49 | count_next_times.append(time.time() - start) 50 | 51 | os.remove(output_file) 52 | if os.path.exists(tmp_index_file): 53 | os.remove(tmp_index_file) 54 | 55 | return build_times, count_next_times 56 | 57 | def plot( 58 | times: list[float], 59 | cls: Literal["InMemoryIndex", "MemmapIndex"], 60 | label: Literal["build", "count_next"] 61 | ) -> None: 62 | x = [10 ** i for i in range(len(times))] 63 | 64 | fig = go.Figure() 65 | fig.add_trace(go.Scatter( 66 | x=x, 67 | y=times, 68 | mode='lines+markers', 69 | line=dict(shape='spline', smoothing=1.3, color='rgb(55, 126, 184)'), 70 | marker=dict(color='rgb(55, 126, 184)') 71 | )) 72 | 73 | fig.update_layout( 74 | title=f'{cls} {label} times over corpus sizes', 75 | xaxis_title='Corpus size (tokens)', 76 | yaxis_title=f'{label.capitalize()} time (seconds)', 77 | width=800, 78 | height=500, 79 | margin=dict(l=80, r=50, t=80, b=80) 80 | ) 81 | 82 | ticktext = [f'1e{int(math.log10(val))}' for val in x] 83 | fig.update_xaxes( 84 | type='log', 85 | tickmode='array', 86 | tickvals=x, 87 | ticktext=ticktext 88 | ) 89 | 90 | # Get the enclosing powers of ten for the range 91 | y_log10_floor = math.floor(math.log10(min(times))) 92 | y_log10_ceil = math.ceil(math.log10(max(times))) 93 | 94 | y_tickvals = np.logspace(y_log10_floor, y_log10_ceil, num=(y_log10_ceil - y_log10_floor) + 1) 95 | 96 | fig.update_yaxes( 97 | type='log', 98 | range=[y_log10_floor, y_log10_ceil], 99 | tickmode='array', 100 | tickvals=y_tickvals, 101 | tickformat="," 102 | ) 103 | 104 | output_path = Path(f"tokengrams/benchmark/{cls}_{label}_times.png") 105 | fig.write_image(output_path, scale=5) 106 | print(f"Plot saved to {str(output_path)}") 
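# Illustrative helper, not used by the CLI below: a minimal programmatic run over a single
# corpus file, mirroring what the __main__ block does. The data path is a placeholder.
def run_example(data_path: str = "tokens.bin") -> None:
    build_times, count_next_times = benchmark(data_path, "MemmapIndex", encoding_width=16)
    plot(build_times, "MemmapIndex", label="build")
    plot(count_next_times, "MemmapIndex", label="count_next")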
107 | 108 | if __name__ == '__main__': 109 | parser = ArgumentParser() 110 | parser.add_argument("--data_path", default=None, help="Path to tokenized corpus file") 111 | parser.add_argument("--encoding_width", default=16, type=int, help="Bits per token") 112 | parser.add_argument("--cls", default="MemmapIndex", choices=["InMemoryIndex", "MemmapIndex"], 113 | help="Index class to benchmark") 114 | args = parser.parse_args() 115 | 116 | if args.data_path: 117 | build_times, count_next_times = benchmark( 118 | args.data_path, 119 | args.cls, 120 | args.encoding_width 121 | ) 122 | plot(build_times, args.cls, label="build") 123 | plot(count_next_times, args.cls, label="count_next") 124 | else: 125 | print("No path to token corpus found, plotting precomputed benchmark data for MemmapIndex:") 126 | benchmark_data = [ 127 | 0.00757908821105957, 128 | 0.0053365230560302734, 129 | 0.0067026615142822266, 130 | 0.009454727172851562, 131 | 0.026259660720825195, 132 | 0.11438727378845215, 133 | 1.1938815116882324, 134 | 10.36899209022522, 135 | 110.39609742164612, 136 | 1263.0914874076843 137 | ] 138 | plot(benchmark_data, "MemmapIndex", label="build") -------------------------------------------------------------------------------- /tokengrams/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/tests/__init__.py -------------------------------------------------------------------------------- /tokengrams/tests/test_gram_index.py: -------------------------------------------------------------------------------- 1 | from itertools import pairwise 2 | from tempfile import NamedTemporaryFile 3 | 4 | from tokengrams import InMemoryIndex, MemmapIndex 5 | from hypothesis import given, strategies as st 6 | 7 | import numpy as np 8 | 9 | 10 | def check_gram_index(index: InMemoryIndex | MemmapIndex, tokens: list[int]): 11 | assert index.is_sorted() 12 | 13 | # Check unigram counts 14 | for t in tokens: 15 | assert index.contains([t]) == (t in tokens) 16 | assert index.count([t]) == tokens.count(t) 17 | 18 | # Check bigram counts 19 | bigrams = list(pairwise(tokens)) 20 | for b in bigrams: 21 | assert index.contains(list(b)) == (b in bigrams) 22 | assert index.count(list(b)) == bigrams.count(b) 23 | 24 | @given( 25 | st.lists( 26 | st.integers(0, 2 ** 16 - 1), min_size=1, 27 | ) 28 | ) 29 | def test_gram_index(tokens: list[int]): 30 | index = InMemoryIndex(tokens) 31 | check_gram_index(index, tokens) 32 | 33 | # Save to disk and check that we can load it back 34 | with NamedTemporaryFile() as f: 35 | index.save_tokens(f.name) 36 | 37 | index = InMemoryIndex.from_token_file(f.name) 38 | 39 | check_gram_index(index, tokens) 40 | 41 | with NamedTemporaryFile() as idx: 42 | index = MemmapIndex.build(f.name, idx.name) 43 | check_gram_index(index, tokens) 44 | 45 | index = MemmapIndex(f.name, idx.name) 46 | check_gram_index(index, tokens) 47 | 48 | # Now check limited token loading 49 | for limit in range(1, len(tokens) + 1): 50 | index = InMemoryIndex.from_token_file(f.name, limit) 51 | check_gram_index(index, tokens[:limit]) 52 | -------------------------------------------------------------------------------- /tokengrams/tests/test_sharded_index.py: -------------------------------------------------------------------------------- 1 | from itertools import pairwise 2 | from tempfile import NamedTemporaryFile 3 | import random 4 | from tokengrams import MemmapIndex, 
ShardedMemmapIndex 5 | 6 | import numpy as np 7 | 8 | 9 | def check_sharded_index(index: ShardedMemmapIndex, tokens: list[int], eos_token: int): 10 | # Check unigram counts 11 | for t in tokens: 12 | assert index.contains([t]) == (t in tokens) 13 | assert index.count([t]) == tokens.count(t) 14 | 15 | # Check bigram counts 16 | bigrams = list(pairwise(tokens)) 17 | for b in bigrams: 18 | if not eos_token in b: 19 | assert index.contains(list(b)) == (b in bigrams) 20 | assert index.count(list(b)) == bigrams.count(b) 21 | 22 | # Check bigram samples 23 | for i in range(len(tokens[:20])): 24 | query = tokens[:i] 25 | sample = index.sample_unsmoothed(query, 2, 1, 1)[0] 26 | assert len(sample) == 1 + len(query) 27 | assert all(s in tokens for s in sample) 28 | 29 | 30 | def test_sharded_index(): 31 | tokens = [random.randint(0, 2**16 - 1) for _ in range(10_000)] 32 | 33 | eos_token = 0 34 | mid = len(tokens) // 2 35 | chunked_tokens = tokens[:mid] + [eos_token] + tokens[mid:] + [eos_token] 36 | 37 | with NamedTemporaryFile() as token_file_1, NamedTemporaryFile() as index_file_1, \ 38 | NamedTemporaryFile() as token_file_2, NamedTemporaryFile() as index_file_2: 39 | 40 | shard_files = [ 41 | (token_file_1.name, index_file_1.name), 42 | (token_file_2.name, index_file_2.name) 43 | ] 44 | 45 | token_file_1.write(np.array(chunked_tokens[:mid + 1], dtype=np.uint16).tobytes()) 46 | token_file_1.flush() 47 | token_file_2.write(np.array(chunked_tokens[mid + 1:], dtype=np.uint16).tobytes()) 48 | token_file_2.flush() 49 | 50 | for token_file, index_file in shard_files: 51 | MemmapIndex.build(token_file, index_file) 52 | 53 | index = ShardedMemmapIndex(shard_files) 54 | check_sharded_index(index, chunked_tokens, eos_token) -------------------------------------------------------------------------------- /tokengrams/tokengrams.pyi: -------------------------------------------------------------------------------- 1 | class InMemoryIndex: 2 | """An n-gram index.""" 3 | 4 | def __init__(self, tokens: list[int], vocab: int = 2**16, verbose: bool = False) -> None: 5 | ... 6 | 7 | @staticmethod 8 | def from_token_file(path: str, token_limit: int | None = None, vocab: int = 2**16, verbose: bool = False) -> "InMemoryIndex": 9 | """Construct a `InMemoryIndex` from a file containing raw little-endian tokens.""" 10 | 11 | def from_disk(self, token_path: str, index_path: str, vocab: int = 2**16) -> "InMemoryIndex": 12 | """Load a pretrained index from disk.""" 13 | 14 | def save_tokens(self, path: str): 15 | """Save the tokens to a file.""" 16 | 17 | def save_index(self, path: str): 18 | """Save the index to disk.""" 19 | 20 | def is_sorted(self) -> bool: 21 | """Check if the index's suffix table is sorted lexicographically. 22 | This is always true for valid indices.""" 23 | 24 | def contains(self, query: list[int]) -> bool: 25 | """Check if `query` has nonzero count. 
Faster than `count(query) > 0`.""" 26 | 27 | def count(self, query: list[int]) -> int: 28 | """Count the number of occurrences of `query` in the index.""" 29 | 30 | def positions(self, query: list[int]) -> list[int]: 31 | """Returns an unordered list of positions where `query` starts in `tokens`.""" 32 | 33 | def count_next(self, query: list[int]) -> list[int]: 34 | """Count the occurrences of each token directly following `query`.""" 35 | 36 | def batch_count_next(self, queries: list[list[int]]) -> list[list[int]]: 37 | """Count the occurrences of each token that directly follows each sequence in `queries`.""" 38 | 39 | def get_smoothed_probs(self, query: list[int]) -> list[float]: 40 | """Compute interpolated Kneser-Ney smoothed token probability distribution using all previous tokens in the query.""" 41 | 42 | def batch_get_smoothed_probs(self, queries: list[list[int]]) -> list[list[float]]: 43 | """Compute interpolated Kneser-Ney smoothed token probability distributions using all previous tokens in each query.""" 44 | 45 | def sample_smoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 46 | """Autoregressively samples num_samples of k characters each from Kneser-Ney smoothed conditional 47 | distributions based on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are 48 | fewer than (n - 1) characters all available characters are used.""" 49 | 50 | def sample_unsmoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 51 | """Autoregressively samples num_samples of k characters each from conditional distributions based 52 | on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are fewer than 53 | (n - 1) characters all available characters are used.""" 54 | 55 | def estimate_deltas(self, n: int): 56 | """Warning: O(k**n) where k is vocabulary size, use with caution. 57 | Improve smoothed model quality by replacing the default delta hyperparameters 58 | for models of order n and below with improved estimates over the entire index. 59 | https://people.eecs.berkeley.edu/~klein/cs294-5/chen_goodman.pdf, page 16.""" 60 | 61 | class MemmapIndex: 62 | """An n-gram index backed by a memory-mapped file.""" 63 | 64 | def __init__(self, token_path: str, index_path: str, vocab: int = 2**16) -> None: 65 | """Load a prebuilt memory-mapped index from a pair of files.""" 66 | 67 | @staticmethod 68 | def build(token_path: str, index_path: str, vocab: int = 2**16, verbose: bool = False) -> "MemmapIndex": 69 | """Build a memory-mapped index from a token file.""" 70 | 71 | def is_sorted(self) -> bool: 72 | """Check if the index's suffix table is sorted lexicographically. 73 | This is always true for valid indices.""" 74 | 75 | def contains(self, query: list[int]) -> bool: 76 | """Check if `query` has nonzero count. 
Faster than `count(query) > 0`.""" 77 | 78 | def count(self, query: list[int]) -> int: 79 | """Count the number of occurrences of `query` in the index.""" 80 | 81 | def positions(self, query: list[int]) -> list[int]: 82 | """Returns an unordered list of positions where `query` starts in `tokens`.""" 83 | 84 | def count_next(self, query: list[int]) -> list[int]: 85 | """Count the occurrences of each token directly following `query`.""" 86 | 87 | def batch_count_next(self, queries: list[list[int]]) -> list[list[int]]: 88 | """Count the occurrences of each token that directly follows each sequence in `queries`.""" 89 | 90 | def sample_smoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 91 | """Autoregressively samples num_samples of k characters each from Kneser-Ney smoothed conditional 92 | distributions based on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are 93 | fewer than (n - 1) characters all available characters are used.""" 94 | 95 | def sample_unsmoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 96 | """Autoregressively samples num_samples of k characters each from conditional distributions based 97 | on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are fewer than 98 | (n - 1) characters all available characters are used.""" 99 | 100 | def get_smoothed_probs(self, query: list[int]) -> list[float]: 101 | """Compute interpolated Kneser-Ney smoothed token probability distribution using all previous tokens in the query.""" 102 | 103 | def batch_get_smoothed_probs(self, queries: list[list[int]]) -> list[list[float]]: 104 | """Compute interpolated Kneser-Ney smoothed token probability distributions using all previous tokens in each query.""" 105 | 106 | def estimate_deltas(self, n: int): 107 | """Warning: O(k**n) where k is vocabulary size, use with caution. 108 | Improve smoothed model quality by replacing the default delta hyperparameters 109 | for models of order n and below with improved estimates over the entire index. 110 | https://people.eecs.berkeley.edu/~klein/cs294-5/chen_goodman.pdf, page 16.""" 111 | 112 | class ShardedMemmapIndex: 113 | """An n-gram index backed by several memory-mapped files.""" 114 | 115 | def __init__(self, paths: list[tuple[str, str]], vocab: int = 2**16) -> None: 116 | """Load a prebuilt memory-mapped index from a list of pairs of files in form (token_file, index_file).""" 117 | 118 | @staticmethod 119 | def build(paths: list[tuple[str, str]], vocab: int = 2**16, verbose: bool = False) -> "ShardedMemmapIndex": 120 | """Build a memory-mapped index from a token file.""" 121 | 122 | def is_sorted(self) -> bool: 123 | """Check if the index's suffix table is sorted lexicographically. 124 | This is always true for valid indices.""" 125 | 126 | def contains(self, query: list[int]) -> bool: 127 | """Check if `query` has nonzero count. 
Faster than `count(query) > 0`.""" 128 | 129 | def count(self, query: list[int]) -> int: 130 | """Count the number of occurrences of `query` in the index.""" 131 | 132 | def count_next(self, query: list[int]) -> list[int]: 133 | """Count the occurrences of each token directly following `query`.""" 134 | 135 | def batch_count_next(self, queries: list[list[int]]) -> list[list[int]]: 136 | """Count the occurrences of each token that directly follows each sequence in `queries`.""" 137 | 138 | def sample_smoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 139 | """Autoregressively samples num_samples of k characters each from Kneser-Ney smoothed conditional 140 | distributions based on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are 141 | fewer than (n - 1) characters all available characters are used.""" 142 | 143 | def sample_unsmoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 144 | """Autoregressively samples num_samples of k characters each from conditional distributions based 145 | on the previous (n - 1) characters (n-gram prefix) in the sequence. If there are fewer than 146 | (n - 1) characters all available characters are used.""" 147 | 148 | def get_smoothed_probs(self, query: list[int]) -> list[float]: 149 | """Compute interpolated Kneser-Ney smoothed token probability distribution using all previous tokens in the query.""" 150 | 151 | def batch_get_smoothed_probs(self, queries: list[list[int]]) -> list[list[float]]: 152 | """Compute interpolated Kneser-Ney smoothed token probability distributions using all previous tokens in each query.""" 153 | 154 | def estimate_deltas(self, n: int): 155 | """Warning: O(k**n) where k is vocabulary size, use with caution. 156 | Improve smoothed model quality by replacing the default delta hyperparameters 157 | for models of order n and below with improved estimates over the entire index. 158 | https://people.eecs.berkeley.edu/~klein/cs294-5/chen_goodman.pdf, page 16.""" 159 | 160 | 161 | class ShardedInMemoryIndex: 162 | """An n-gram index backed by several memory-mapped files.""" 163 | 164 | def __init__(self, paths: list[tuple[str, str]], vocab: int = 2**16) -> None: 165 | """Load a prebuilt memory-mapped index from a list of pairs of files in form (token_file, index_file).""" 166 | 167 | def is_sorted(self) -> bool: 168 | """Check if the index's suffix table is sorted lexicographically. 169 | This is always true for valid indices.""" 170 | 171 | def contains(self, query: list[int]) -> bool: 172 | """Check if `query` has nonzero count. Faster than `count(query) > 0`.""" 173 | 174 | def count(self, query: list[int]) -> int: 175 | """Count the number of occurrences of `query` in the index.""" 176 | 177 | def count_next(self, query: list[int]) -> list[int]: 178 | """Count the occurrences of each token directly following `query`.""" 179 | 180 | def batch_count_next(self, queries: list[list[int]]) -> list[list[int]]: 181 | """Count the occurrences of each token that directly follows each sequence in `queries`.""" 182 | 183 | def sample_smoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]: 184 | """Autoregressively samples num_samples of k characters each from Kneser-Ney smoothed conditional 185 | distributions based on the previous (n - 1) characters (n-gram prefix) in the sequence. 
186 |         fewer than (n - 1) tokens, all available tokens are used."""
187 | 
188 |     def sample_unsmoothed(self, query: list[int], n: int, k: int, num_samples: int) -> list[list[int]]:
189 |         """Autoregressively sample `num_samples` sequences of `k` tokens each from conditional distributions based
190 |         on the previous (n - 1) tokens (n-gram prefix) in the sequence. If there are fewer than
191 |         (n - 1) tokens, all available tokens are used."""
192 | 
193 |     def get_smoothed_probs(self, query: list[int]) -> list[float]:
194 |         """Compute interpolated Kneser-Ney smoothed token probability distribution using all previous tokens in the query."""
195 | 
196 |     def batch_get_smoothed_probs(self, queries: list[list[int]]) -> list[list[float]]:
197 |         """Compute interpolated Kneser-Ney smoothed token probability distributions using all previous tokens in each query."""
198 | 
199 |     def estimate_deltas(self, n: int):
200 |         """Warning: O(k**n) where k is vocabulary size, use with caution.
201 |         Improve smoothed model quality by replacing the default delta hyperparameters
202 |         for models of order n and below with improved estimates over the entire index.
203 |         https://people.eecs.berkeley.edu/~klein/cs294-5/chen_goodman.pdf, page 16."""
204 | 
--------------------------------------------------------------------------------
/tokengrams/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/tokengrams/9efe08829a1667e6e65f5cac29babf03c5bbf020/tokengrams/utils/__init__.py
--------------------------------------------------------------------------------
/tokengrams/utils/tokenize_hf_dataset.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | from typing import Union, Generator
3 | from pathlib import Path
4 | 
5 | import numpy as np
6 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, concatenate_datasets
7 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
8 | from tqdm import tqdm
9 | 
10 | 
11 | def tokenize_hf_dataset(
12 |     dataset: Dataset | DatasetDict | IterableDataset | IterableDatasetDict,
13 |     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
14 |     output_path: Union[Path, str],
15 |     text_key="text",
16 |     append_eod: bool = False,
17 |     workers: int = 1,
18 | ):
19 |     output_path = Path(output_path)
20 | 
21 |     batch_size = 10_000
22 | 
23 |     eos_token = tokenizer.eos_token_id if append_eod else None
24 | 
25 |     vocab_size = get_vocab_size(tokenizer)
26 |     if vocab_size > 2**32:
27 |         raise ValueError(f"Tokenizer vocab size {vocab_size} is too large for uint32")
28 | 
29 |     data = get_dataset_iterator(dataset, batch_size)
30 | 
31 |     # Tokenize and save as memory-mapped array
32 |     total_tokens = tokenize_and_write_mmap(
33 |         data,
34 |         tokenizer,
35 |         output_path,
36 |         eos_token=eos_token,
37 |         text_key=text_key,
38 |         num_workers=workers,
39 |         dtype=np.dtype(np.uint16 if vocab_size < 2**16 else np.uint32)
40 |     )
41 |     print(f"{total_tokens} tokens saved to {output_path}")
42 | 
43 | def get_vocab_size(tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> int:
44 |     """Get the vocab size of the tokenizer."""
45 |     if hasattr(tokenizer, 'vocab_size'):
46 |         return tokenizer.vocab_size
47 |     elif hasattr(tokenizer, 'get_vocab'):
48 |         return len(tokenizer.get_vocab())
49 |     else:
50 |         return len(tokenizer)
51 | 
52 | def get_dataset_iterator(data: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict], batch_size: int):
53 |     """Get an iterator for the dataset, handling different dataset types."""
54 |     if isinstance(data, IterableDataset):
55 |         return iter(data.iter(batch_size=batch_size))
56 |     elif isinstance(data, Dataset):
57 |         return (
58 |             data.select(range(i, min(i + batch_size, len(data))))
59 |             for i in range(0, len(data), batch_size)
60 |         )
61 |     elif isinstance(data, DatasetDict) or isinstance(data, IterableDatasetDict):
62 |         # Concatenate all available splits
63 |         concatenated_dataset = concatenate_datasets(list(data.values()))
64 |         return concatenated_dataset.iter(batch_size=batch_size)
65 |     else:
66 |         raise ValueError(f"Unsupported dataset type: {type(data)}")
67 | 
68 | def tokenize_batch(args):
69 |     batch, tokenizer, text_key, eos_token = args
70 |     tokenized = tokenizer(batch[text_key], add_special_tokens=False, truncation=False, padding=False)
71 |     suffix = [eos_token] if eos_token is not None else []
72 | 
73 |     all_tokens = []
74 |     for tokens in tokenized['input_ids']:
75 |         all_tokens.extend(tokens)
76 |         all_tokens.extend(suffix)
77 |     return all_tokens
78 | 
79 | def tokenize_and_write_mmap(
80 |     data: Generator,
81 |     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
82 |     output_path: Path,
83 |     text_key: str = "text",
84 |     buffer_size: int = 10_000_000,
85 |     eos_token: int | None = None,
86 |     num_workers: int = 4,
87 |     dtype: np.dtype = np.dtype(np.uint16)
88 | ):
89 |     mmap = np.memmap(output_path, dtype=dtype, mode='w+', shape=(buffer_size,))
90 | 
91 |     total_tokens = 0
92 |     pool = mp.Pool(num_workers)
93 | 
94 |     pbar = tqdm(desc="Tokenizing")
95 |     for batch in data:
96 |         tokenize_args = [(batch, tokenizer, text_key, eos_token)]
97 |         new_tokens = pool.map(tokenize_batch, tokenize_args)[0]  # tokenize one batch per iteration in a worker process
98 | 
99 |         while total_tokens + len(new_tokens) > mmap.shape[0]:  # keep doubling until the new batch fits
100 |             mmap = np.memmap(output_path, dtype=dtype, mode='r+', shape=(mmap.shape[0] * 2,))
101 | 
102 |         mmap[total_tokens:total_tokens + len(new_tokens)] = new_tokens
103 |         total_tokens += len(new_tokens)
104 |         pbar.update(len(batch[text_key]))  # number of documents in this batch, not the number of dict keys
105 | 
106 |     pool.close()
107 |     pool.join()
108 | 
109 |     # Resize mmap to actual size
110 |     with open(output_path, 'r+b') as f:
111 |         f.truncate(total_tokens * dtype.itemsize)
112 | 
113 |     mmap = np.memmap(output_path, dtype=dtype, mode='r+', shape=(total_tokens,))
114 |     mmap.flush()
115 | 
116 |     pbar.close()
117 |     return total_tokens
--------------------------------------------------------------------------------
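
Usage sketch (illustrative only, not a file in this repository): assuming the tokengrams package is installed and `MemmapIndex` is re-exported from the package root, the tokenizer utility and the stubbed API above compose as follows. The dataset, tokenizer, and file names are placeholders.

from datasets import load_dataset
from transformers import AutoTokenizer

from tokengrams import MemmapIndex
from tokengrams.utils.tokenize_hf_dataset import tokenize_hf_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # vocab < 2**16, so tokens are written as uint16

# Tokenize a Hugging Face dataset into a flat binary token file.
tokenize_hf_dataset(
    dataset=load_dataset("wikitext", "wikitext-2-raw-v1", split="train"),
    tokenizer=tokenizer,
    output_path="document.bin",
    text_key="text",
    append_eod=True,
)

# Build the on-disk suffix table next to the token file, then query it.
index = MemmapIndex.build("document.bin", "document.idx")

query = tokenizer.encode("The quick brown")
print(index.count(query))            # occurrences of the query n-gram
print(len(index.count_next(query)))  # one count per token id (2**16 entries with the default vocab)

# Sample 5 continuations of 10 tokens each from an unsmoothed 4-gram model.
print(index.sample_unsmoothed(query, 4, 10, 5))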