├── .github
│   └── workflows
│       ├── publish-wheel.yml
│       ├── python.yml
│       └── rust.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── Cargo.lock
├── Cargo.toml
├── LICENSE.md
├── README.md
├── demos
│   └── generate_text.ipynb
├── environment.yml
├── pyproject.toml
├── src
│   ├── bindings
│   │   ├── in_memory_index.rs
│   │   ├── memmap_index.rs
│   │   ├── mod.rs
│   │   ├── sharded_in_memory_index.rs
│   │   └── sharded_memmap_index.rs
│   ├── in_memory_index.rs
│   ├── lib.rs
│   ├── memmap_index.rs
│   ├── mmap_slice.rs
│   ├── par_quicksort.rs
│   ├── sample.rs
│   ├── sharded_in_memory_index.rs
│   ├── sharded_memmap_index.rs
│   ├── table.rs
│   └── util.rs
├── tests
│   └── tests.rs
└── tokengrams
    ├── __init__.py
    ├── benchmark
    │   ├── InMemoryIndex_build_times.png
    │   ├── InMemoryIndex_count_next_times.png
    │   ├── MemmapIndex_build_times.png
    │   ├── MemmapIndex_count_next_times.png
    │   └── benchmark.py
    ├── tests
    │   ├── __init__.py
    │   ├── test_gram_index.py
    │   └── test_sharded_index.py
    ├── tokengrams.pyi
    └── utils
        ├── __init__.py
        └── tokenize_hf_dataset.py
/.github/workflows/publish-wheel.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 | on:
3 | workflow_dispatch:
4 |
5 | jobs:
6 | build:
7 | name: Build wheels for ${{ matrix.os }} - Python ${{ matrix.python-version }}
8 | strategy:
9 | fail-fast: false
10 | matrix:
11 | include:
12 | # 3.10
13 | - os: ubuntu-latest
14 | target: x86_64-unknown-linux-gnu
15 | python-version: '3.10'
16 | - os: ubuntu-latest
17 | target: aarch64-unknown-linux-gnu
18 | python-version: '3.10'
19 | - os: macos-latest
20 | target: x86_64-apple-darwin
21 | python-version: '3.10'
22 | - os: macos-latest
23 | target: aarch64-apple-darwin
24 | python-version: '3.10'
25 | - os: windows-latest
26 | target: x86_64-pc-windows-msvc
27 | python-version: '3.10'
28 | # 3.11-
29 | - os: ubuntu-latest
30 | target: x86_64-unknown-linux-gnu
31 | python-version: '3.11'
32 | - os: ubuntu-latest
33 | target: aarch64-unknown-linux-gnu
34 | python-version: '3.11'
35 | - os: macos-latest
36 | target: x86_64-apple-darwin
37 | python-version: '3.11'
38 | - os: macos-latest
39 | target: aarch64-apple-darwin
40 | python-version: '3.11'
41 | - os: windows-latest
42 | target: x86_64-pc-windows-msvc
43 | python-version: '3.11'
44 |
45 | runs-on: ${{ matrix.os }}
46 | steps:
47 | - uses: actions/checkout@v3
48 | - name: Set up Python ${{ matrix.python-version }}
49 | uses: actions/setup-python@v4
50 | with:
51 | python-version: ${{ matrix.python-version }}
52 | - name: Set up Rust
53 | uses: actions-rs/toolchain@v1
54 | with:
55 | profile: minimal
56 | toolchain: stable
57 | target: ${{ matrix.target }}
58 | override: true
59 | - name: Add macOS Rust targets
60 | if: matrix.os == 'macos-latest'
61 | run: |
62 | rustup target add x86_64-apple-darwin aarch64-apple-darwin
63 | - name: Build wheels
64 | uses: PyO3/maturin-action@v1
65 | with:
66 | target: ${{ matrix.target }}
67 |           # The matrix already builds x86_64-apple-darwin and aarch64-apple-darwin
68 |           # separately; forcing universal2 here built the same universal wheel twice.
69 |           args: --release --out dist --interpreter python${{ matrix.python-version }}
68 | manylinux: auto
70 |       - name: Upload wheels
71 |         # upload-artifact v2 is deprecated and no longer runs on GitHub-hosted
72 |         # runners; v4 also requires a unique artifact name per matrix job.
73 |         uses: actions/upload-artifact@v4
74 |         with:
75 |           name: wheels-${{ matrix.target }}-py${{ matrix.python-version }}
76 |           path: dist
77 |
78 |   build-sdist:
79 |     name: Build source distribution
80 |     runs-on: ubuntu-latest
81 |     steps:
82 |       - uses: actions/checkout@v3
83 |       - name: Build sdist
84 |         uses: PyO3/maturin-action@v1
85 |         with:
86 |           command: sdist
87 |           args: --out dist
88 |       - name: Upload sdist
89 |         uses: actions/upload-artifact@v4
90 |         with:
91 |           name: sdist
92 |           path: dist
93 |
94 |   publish:
95 |     name: Publish to PyPI
96 |     needs: [build, build-sdist]
97 |     runs-on: ubuntu-latest
98 |     steps:
99 |       - uses: actions/download-artifact@v4
100 |         with:
101 |           # No `name:` downloads every artifact; merge the per-job wheel
102 |           # artifacts and the sdist into a single dist/ directory.
103 |           path: dist
104 |           merge-multiple: true
105 |       - name: Publish to PyPI
106 |         uses: pypa/gh-action-pypi-publish@release/v1
107 |         with:
108 |           user: __token__
109 |           password: ${{ secrets.PYPI_API_TOKEN }}
110 |           packages_dir: dist/
111 |           skip_existing: true
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
1 | name: Python Package using Conda
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build-linux:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | max-parallel: 5
10 |
11 | steps:
12 | - uses: actions/checkout@v3
13 | - name: Set up Conda
14 | uses: conda-incubator/setup-miniconda@v3
15 | with:
16 | python-version: "3.10"
17 | miniforge-version: latest
18 | use-mamba: true
19 | mamba-version: "*"
20 | - name: Test Python
21 | env:
22 | PYTHONPATH: /home/runner/work/tokengrams/tokengrams
23 | shell: bash -l {0}
24 | run: |
25 | mamba install -c conda-forge numpy pytest hypothesis maturin
26 | maturin develop
27 | maturin build
28 | python -m pip install --user ./target/wheels/tokengrams*.whl
29 | pytest
--------------------------------------------------------------------------------
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Rust
2 |
3 | on:
4 | push:
5 | branches: [ "master" ]
6 | pull_request:
7 | branches: [ "master" ]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 |     - name: Build
20 |       # Disable the default pyo3/extension-module feature: a standalone test
21 |       # binary cannot link as a Python extension module.
22 |       run: cargo build --verbose --no-default-features --features rust
23 |     - name: Run tests
24 |       run: cargo test --verbose --no-default-features --features rust
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | # MacOS
166 | .DS_Store
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "rust-analyzer.linkedProjects": [
3 |     "./Cargo.toml"
4 |   ],
5 | "python.testing.pytestArgs": [
6 | "tests"
7 | ],
8 | "python.testing.unittestEnabled": false,
9 | "python.testing.pytestEnabled": true
10 | }
--------------------------------------------------------------------------------
/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "anyhow"
7 | version = "1.0.81"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
10 |
11 | [[package]]
12 | name = "autocfg"
13 | version = "1.1.0"
14 | source = "registry+https://github.com/rust-lang/crates.io-index"
15 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
16 |
17 | [[package]]
18 | name = "bincode"
19 | version = "1.3.3"
20 | source = "registry+https://github.com/rust-lang/crates.io-index"
21 | checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
22 | dependencies = [
23 | "serde",
24 | ]
25 |
26 | [[package]]
27 | name = "cfg-if"
28 | version = "1.0.0"
29 | source = "registry+https://github.com/rust-lang/crates.io-index"
30 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
31 |
32 | [[package]]
33 | name = "console"
34 | version = "0.15.8"
35 | source = "registry+https://github.com/rust-lang/crates.io-index"
36 | checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"
37 | dependencies = [
38 | "encode_unicode",
39 | "lazy_static",
40 | "libc",
41 | "unicode-width",
42 | "windows-sys",
43 | ]
44 |
45 | [[package]]
46 | name = "crossbeam-deque"
47 | version = "0.8.5"
48 | source = "registry+https://github.com/rust-lang/crates.io-index"
49 | checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
50 | dependencies = [
51 | "crossbeam-epoch",
52 | "crossbeam-utils",
53 | ]
54 |
55 | [[package]]
56 | name = "crossbeam-epoch"
57 | version = "0.9.18"
58 | source = "registry+https://github.com/rust-lang/crates.io-index"
59 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
60 | dependencies = [
61 | "crossbeam-utils",
62 | ]
63 |
64 | [[package]]
65 | name = "crossbeam-utils"
66 | version = "0.8.19"
67 | source = "registry+https://github.com/rust-lang/crates.io-index"
68 | checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
69 |
70 | [[package]]
71 | name = "either"
72 | version = "1.10.0"
73 | source = "registry+https://github.com/rust-lang/crates.io-index"
74 | checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
75 |
76 | [[package]]
77 | name = "encode_unicode"
78 | version = "0.3.6"
79 | source = "registry+https://github.com/rust-lang/crates.io-index"
80 | checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
81 |
82 | [[package]]
83 | name = "erased-serde"
84 | version = "0.4.5"
85 | source = "registry+https://github.com/rust-lang/crates.io-index"
86 | checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
87 | dependencies = [
88 | "serde",
89 | "typeid",
90 | ]
91 |
92 | [[package]]
93 | name = "funty"
94 | version = "2.0.0"
95 | source = "registry+https://github.com/rust-lang/crates.io-index"
96 | checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
97 |
98 | [[package]]
99 | name = "getrandom"
100 | version = "0.1.16"
101 | source = "registry+https://github.com/rust-lang/crates.io-index"
102 | checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
103 | dependencies = [
104 | "cfg-if",
105 | "libc",
106 | "wasi 0.9.0+wasi-snapshot-preview1",
107 | ]
108 |
109 | [[package]]
110 | name = "getrandom"
111 | version = "0.2.11"
112 | source = "registry+https://github.com/rust-lang/crates.io-index"
113 | checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
114 | dependencies = [
115 | "cfg-if",
116 | "libc",
117 | "wasi 0.11.0+wasi-snapshot-preview1",
118 | ]
119 |
120 | [[package]]
121 | name = "heck"
122 | version = "0.5.0"
123 | source = "registry+https://github.com/rust-lang/crates.io-index"
124 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
125 |
126 | [[package]]
127 | name = "indicatif"
128 | version = "0.17.8"
129 | source = "registry+https://github.com/rust-lang/crates.io-index"
130 | checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3"
131 | dependencies = [
132 | "console",
133 | "instant",
134 | "number_prefix",
135 | "portable-atomic",
136 | "unicode-width",
137 | ]
138 |
139 | [[package]]
140 | name = "indoc"
141 | version = "2.0.4"
142 | source = "registry+https://github.com/rust-lang/crates.io-index"
143 | checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
144 |
145 | [[package]]
146 | name = "instant"
147 | version = "0.1.12"
148 | source = "registry+https://github.com/rust-lang/crates.io-index"
149 | checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
150 | dependencies = [
151 | "cfg-if",
152 | ]
153 |
154 | [[package]]
155 | name = "inventory"
156 | version = "0.3.15"
157 | source = "registry+https://github.com/rust-lang/crates.io-index"
158 | checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767"
159 |
160 | [[package]]
161 | name = "lazy_static"
162 | version = "1.4.0"
163 | source = "registry+https://github.com/rust-lang/crates.io-index"
164 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
165 |
166 | [[package]]
167 | name = "libc"
168 | version = "0.2.151"
169 | source = "registry+https://github.com/rust-lang/crates.io-index"
170 | checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4"
171 |
172 | [[package]]
173 | name = "memmap2"
174 | version = "0.9.4"
175 | source = "registry+https://github.com/rust-lang/crates.io-index"
176 | checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322"
177 | dependencies = [
178 | "libc",
179 | ]
180 |
181 | [[package]]
182 | name = "memoffset"
183 | version = "0.9.0"
184 | source = "registry+https://github.com/rust-lang/crates.io-index"
185 | checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
186 | dependencies = [
187 | "autocfg",
188 | ]
189 |
190 | [[package]]
191 | name = "number_prefix"
192 | version = "0.4.0"
193 | source = "registry+https://github.com/rust-lang/crates.io-index"
194 | checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
195 |
196 | [[package]]
197 | name = "once_cell"
198 | version = "1.19.0"
199 | source = "registry+https://github.com/rust-lang/crates.io-index"
200 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
201 |
202 | [[package]]
203 | name = "portable-atomic"
204 | version = "1.6.0"
205 | source = "registry+https://github.com/rust-lang/crates.io-index"
206 | checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
207 |
208 | [[package]]
209 | name = "ppv-lite86"
210 | version = "0.2.17"
211 | source = "registry+https://github.com/rust-lang/crates.io-index"
212 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
213 |
214 | [[package]]
215 | name = "proc-macro2"
216 | version = "1.0.86"
217 | source = "registry+https://github.com/rust-lang/crates.io-index"
218 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
219 | dependencies = [
220 | "unicode-ident",
221 | ]
222 |
223 | [[package]]
224 | name = "pyo3"
225 | version = "0.22.2"
226 | source = "registry+https://github.com/rust-lang/crates.io-index"
227 | checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
228 | dependencies = [
229 | "anyhow",
230 | "cfg-if",
231 | "indoc",
232 | "libc",
233 | "memoffset",
234 | "once_cell",
235 | "portable-atomic",
236 | "pyo3-build-config",
237 | "pyo3-ffi",
238 | "pyo3-macros",
239 | "unindent",
240 | ]
241 |
242 | [[package]]
243 | name = "pyo3-build-config"
244 | version = "0.22.2"
245 | source = "registry+https://github.com/rust-lang/crates.io-index"
246 | checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
247 | dependencies = [
248 | "once_cell",
249 | "target-lexicon",
250 | ]
251 |
252 | [[package]]
253 | name = "pyo3-ffi"
254 | version = "0.22.2"
255 | source = "registry+https://github.com/rust-lang/crates.io-index"
256 | checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
257 | dependencies = [
258 | "libc",
259 | "pyo3-build-config",
260 | ]
261 |
262 | [[package]]
263 | name = "pyo3-macros"
264 | version = "0.22.2"
265 | source = "registry+https://github.com/rust-lang/crates.io-index"
266 | checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
267 | dependencies = [
268 | "proc-macro2",
269 | "pyo3-macros-backend",
270 | "quote",
271 | "syn",
272 | ]
273 |
274 | [[package]]
275 | name = "pyo3-macros-backend"
276 | version = "0.22.2"
277 | source = "registry+https://github.com/rust-lang/crates.io-index"
278 | checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
279 | dependencies = [
280 | "heck",
281 | "proc-macro2",
282 | "pyo3-build-config",
283 | "quote",
284 | "syn",
285 | ]
286 |
287 | [[package]]
288 | name = "quickcheck"
289 | version = "0.9.2"
290 | source = "registry+https://github.com/rust-lang/crates.io-index"
291 | checksum = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f"
292 | dependencies = [
293 | "rand 0.7.3",
294 | "rand_core 0.5.1",
295 | ]
296 |
297 | [[package]]
298 | name = "quote"
299 | version = "1.0.35"
300 | source = "registry+https://github.com/rust-lang/crates.io-index"
301 | checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
302 | dependencies = [
303 | "proc-macro2",
304 | ]
305 |
306 | [[package]]
307 | name = "rand"
308 | version = "0.7.3"
309 | source = "registry+https://github.com/rust-lang/crates.io-index"
310 | checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
311 | dependencies = [
312 | "getrandom 0.1.16",
313 | "libc",
314 | "rand_chacha 0.2.2",
315 | "rand_core 0.5.1",
316 | "rand_hc",
317 | ]
318 |
319 | [[package]]
320 | name = "rand"
321 | version = "0.8.5"
322 | source = "registry+https://github.com/rust-lang/crates.io-index"
323 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
324 | dependencies = [
325 | "libc",
326 | "rand_chacha 0.3.1",
327 | "rand_core 0.6.4",
328 | ]
329 |
330 | [[package]]
331 | name = "rand_chacha"
332 | version = "0.2.2"
333 | source = "registry+https://github.com/rust-lang/crates.io-index"
334 | checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
335 | dependencies = [
336 | "ppv-lite86",
337 | "rand_core 0.5.1",
338 | ]
339 |
340 | [[package]]
341 | name = "rand_chacha"
342 | version = "0.3.1"
343 | source = "registry+https://github.com/rust-lang/crates.io-index"
344 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
345 | dependencies = [
346 | "ppv-lite86",
347 | "rand_core 0.6.4",
348 | ]
349 |
350 | [[package]]
351 | name = "rand_core"
352 | version = "0.5.1"
353 | source = "registry+https://github.com/rust-lang/crates.io-index"
354 | checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
355 | dependencies = [
356 | "getrandom 0.1.16",
357 | ]
358 |
359 | [[package]]
360 | name = "rand_core"
361 | version = "0.6.4"
362 | source = "registry+https://github.com/rust-lang/crates.io-index"
363 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
364 | dependencies = [
365 | "getrandom 0.2.11",
366 | ]
367 |
368 | [[package]]
369 | name = "rand_hc"
370 | version = "0.2.0"
371 | source = "registry+https://github.com/rust-lang/crates.io-index"
372 | checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
373 | dependencies = [
374 | "rand_core 0.5.1",
375 | ]
376 |
377 | [[package]]
378 | name = "rayon"
379 | version = "1.10.0"
380 | source = "registry+https://github.com/rust-lang/crates.io-index"
381 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
382 | dependencies = [
383 | "either",
384 | "rayon-core",
385 | ]
386 |
387 | [[package]]
388 | name = "rayon-core"
389 | version = "1.12.1"
390 | source = "registry+https://github.com/rust-lang/crates.io-index"
391 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
392 | dependencies = [
393 | "crossbeam-deque",
394 | "crossbeam-utils",
395 | ]
396 |
397 | [[package]]
398 | name = "serde"
399 | version = "1.0.197"
400 | source = "registry+https://github.com/rust-lang/crates.io-index"
401 | checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
402 | dependencies = [
403 | "serde_derive",
404 | ]
405 |
406 | [[package]]
407 | name = "serde_derive"
408 | version = "1.0.197"
409 | source = "registry+https://github.com/rust-lang/crates.io-index"
410 | checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
411 | dependencies = [
412 | "proc-macro2",
413 | "quote",
414 | "syn",
415 | ]
416 |
417 | [[package]]
418 | name = "syn"
419 | version = "2.0.72"
420 | source = "registry+https://github.com/rust-lang/crates.io-index"
421 | checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af"
422 | dependencies = [
423 | "proc-macro2",
424 | "quote",
425 | "unicode-ident",
426 | ]
427 |
428 | [[package]]
429 | name = "target-lexicon"
430 | version = "0.12.15"
431 | source = "registry+https://github.com/rust-lang/crates.io-index"
432 | checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2"
433 |
434 | [[package]]
435 | name = "tokengrams"
436 | version = "0.3.3"
437 | dependencies = [
438 | "anyhow",
439 | "bincode",
440 | "funty",
441 | "indicatif",
442 | "memmap2",
443 | "pyo3",
444 | "quickcheck",
445 | "rand 0.8.5",
446 | "rayon",
447 | "rayon-core",
448 | "serde",
449 | "typetag",
450 | "utf16_literal",
451 | ]
452 |
453 | [[package]]
454 | name = "typeid"
455 | version = "1.0.0"
456 | source = "registry+https://github.com/rust-lang/crates.io-index"
457 | checksum = "059d83cc991e7a42fc37bd50941885db0888e34209f8cfd9aab07ddec03bc9cf"
458 |
459 | [[package]]
460 | name = "typetag"
461 | version = "0.2.17"
462 | source = "registry+https://github.com/rust-lang/crates.io-index"
463 | checksum = "1f7ec175048b96728c30152928c52161bfcc8ea2bd3fb7ed4ccb7dec060b2834"
464 | dependencies = [
465 | "erased-serde",
466 | "inventory",
467 | "once_cell",
468 | "serde",
469 | "typetag-impl",
470 | ]
471 |
472 | [[package]]
473 | name = "typetag-impl"
474 | version = "0.2.17"
475 | source = "registry+https://github.com/rust-lang/crates.io-index"
476 | checksum = "84b5474fd169a5b02b6782b56bbbbff27e85947d4488e5501123687db3148647"
477 | dependencies = [
478 | "proc-macro2",
479 | "quote",
480 | "syn",
481 | ]
482 |
483 | [[package]]
484 | name = "unicode-ident"
485 | version = "1.0.12"
486 | source = "registry+https://github.com/rust-lang/crates.io-index"
487 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
488 |
489 | [[package]]
490 | name = "unicode-width"
491 | version = "0.1.11"
492 | source = "registry+https://github.com/rust-lang/crates.io-index"
493 | checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
494 |
495 | [[package]]
496 | name = "unindent"
497 | version = "0.2.3"
498 | source = "registry+https://github.com/rust-lang/crates.io-index"
499 | checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
500 |
501 | [[package]]
502 | name = "utf16_literal"
503 | version = "0.2.1"
504 | source = "registry+https://github.com/rust-lang/crates.io-index"
505 | checksum = "316f90fe4a7beb941ce0b6806ba7386f1a515155a42def71825f3f9a232e3f48"
506 |
507 | [[package]]
508 | name = "wasi"
509 | version = "0.9.0+wasi-snapshot-preview1"
510 | source = "registry+https://github.com/rust-lang/crates.io-index"
511 | checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
512 |
513 | [[package]]
514 | name = "wasi"
515 | version = "0.11.0+wasi-snapshot-preview1"
516 | source = "registry+https://github.com/rust-lang/crates.io-index"
517 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
518 |
519 | [[package]]
520 | name = "windows-sys"
521 | version = "0.52.0"
522 | source = "registry+https://github.com/rust-lang/crates.io-index"
523 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
524 | dependencies = [
525 | "windows-targets",
526 | ]
527 |
528 | [[package]]
529 | name = "windows-targets"
530 | version = "0.52.4"
531 | source = "registry+https://github.com/rust-lang/crates.io-index"
532 | checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
533 | dependencies = [
534 | "windows_aarch64_gnullvm",
535 | "windows_aarch64_msvc",
536 | "windows_i686_gnu",
537 | "windows_i686_msvc",
538 | "windows_x86_64_gnu",
539 | "windows_x86_64_gnullvm",
540 | "windows_x86_64_msvc",
541 | ]
542 |
543 | [[package]]
544 | name = "windows_aarch64_gnullvm"
545 | version = "0.52.4"
546 | source = "registry+https://github.com/rust-lang/crates.io-index"
547 | checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
548 |
549 | [[package]]
550 | name = "windows_aarch64_msvc"
551 | version = "0.52.4"
552 | source = "registry+https://github.com/rust-lang/crates.io-index"
553 | checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
554 |
555 | [[package]]
556 | name = "windows_i686_gnu"
557 | version = "0.52.4"
558 | source = "registry+https://github.com/rust-lang/crates.io-index"
559 | checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
560 |
561 | [[package]]
562 | name = "windows_i686_msvc"
563 | version = "0.52.4"
564 | source = "registry+https://github.com/rust-lang/crates.io-index"
565 | checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
566 |
567 | [[package]]
568 | name = "windows_x86_64_gnu"
569 | version = "0.52.4"
570 | source = "registry+https://github.com/rust-lang/crates.io-index"
571 | checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
572 |
573 | [[package]]
574 | name = "windows_x86_64_gnullvm"
575 | version = "0.52.4"
576 | source = "registry+https://github.com/rust-lang/crates.io-index"
577 | checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
578 |
579 | [[package]]
580 | name = "windows_x86_64_msvc"
581 | version = "0.52.4"
582 | source = "registry+https://github.com/rust-lang/crates.io-index"
583 | checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
584 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "tokengrams"
3 | description = "Compute n-gram statistics and model language over pre-tokenized text corpora used to train large language models."
4 | license = "MIT"
5 | version = "0.3.3"
6 | edition = "2021"
7 |
8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
9 | [lib]
10 | name = "tokengrams"
11 | crate-type = ["cdylib", "rlib"]
12 |
13 | [features]
14 | default = ["pyo3/extension-module"]
15 | rust = []
16 |
17 | [dependencies]
18 | anyhow = "1.0.81"
19 | bincode = "1.3.3"
20 | funty = "2.0.0"
21 | indicatif = "0.17.8"
22 | memmap2 = "0.9.4"
23 | # "extension-module" is enabled via the default feature above, so it can be
24 | # turned off (e.g. `--no-default-features --features rust`) for `cargo test`.
25 | pyo3 = { version = "0.22.2", features = ["anyhow"] }
24 | rand = "0.8.5"
25 | rayon = "1.10.0"
26 | rayon-core = "1.12.1"
27 | serde = { version = "1.0.197", features = ["derive"] }
28 | typetag = "0.2.17"
29 | utf16_literal = "0.2.1"
30 |
31 | [[test]]
32 | name = "tests"
33 | path = "tests/tests.rs"
34 |
35 | [dev-dependencies]
36 | quickcheck = { version = "0.9", default-features = false }
37 | rand = "0.8.4"
38 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 EleutherAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tokengrams
2 | Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.
3 |
4 | Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to generate text or compute the perplexity of a given text.
5 |
6 | The backend is written in Rust, and the Python bindings are generated using [PyO3](https://github.com/PyO3/pyo3).
7 |
8 | # Installation
9 |
10 | ```bash
11 | pip install tokengrams
12 | ```
13 |
14 | # Usage
15 |
16 | [Full text generation demo](https://colab.research.google.com/drive/1CEHoIjLboGl8YPbIqnWJlPYMm1wVOrrj?usp=sharing)
17 |
18 | ## Preparing data
19 |
20 | Use a dataset of u16 or u32 tokens, or prepare one from a HuggingFace dataset.
21 |
22 | ```python
23 | # Get pre-tokenized dataset
24 | from huggingface_hub import HfApi, hf_hub_download
25 |
26 | hf_hub_download(
27 | repo_id="EleutherAI/pile-standard-pythia-preshuffled",
28 | repo_type="dataset",
29 | filename="document-00000-of-00020.bin",
30 | local_dir="."
31 | )
32 | ```
33 | ```python
34 | # Tokenize HF dataset
35 | from tokengrams import tokenize_hf_dataset
36 | from datasets import load_dataset
37 | from transformers import AutoTokenizer
38 |
39 | tokenize_hf_dataset(
40 | dataset=load_dataset("EleutherAI/lambada_openai", "en"),
41 | tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-160m"),
42 | output_path="lambada.bin",
43 | text_key="text",
44 | append_eod=True,
45 | workers=1,
46 | )
47 | ```
48 |
49 | ## Building an index
50 | ```python
51 | from tokengrams import MemmapIndex
52 |
53 | # Create a new index from an on-disk corpus of u16 tokens and save it to a .idx file.
54 | # Set verbose to true to include a progress bar for the index sort.
55 | index = MemmapIndex.build(
56 | "document-00000-of-00020.bin",
57 | "document-00000-of-00020.idx",
58 | vocab=2**16,
59 | verbose=True
60 | )
61 |
62 | # True for any valid index.
63 | print(index.is_sorted())
64 |
65 | # Get the count of "hello world" in the corpus.
66 | from transformers import AutoTokenizer
67 |
68 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
69 | print(index.count(tokenizer.encode("hello world")))
70 |
71 | # You can now load the index from disk later using __init__
72 | index = MemmapIndex(
73 | "document-00000-of-00020.bin",
74 | "document-00000-of-00020.idx",
75 | vocab=2**16
76 | )
77 | ```
78 |
79 | ## Using an index
80 |
81 | ```python
82 | # Count how often each token in the corpus succeeds "hello world".
83 | print(index.count_next(tokenizer.encode("hello world")))
84 | print(index.batch_count_next(
85 | [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
86 | ))
87 |
88 | # Get smoothed probabilities for query continuations
89 | print(index.smoothed_probs(tokenizer.encode("hello world")))
90 | print(index.batch_smoothed_probs(
91 | [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
92 | ))
93 |
94 | # Autoregressively sample 10 tokens using 5-gram language statistics. Initial
95 | # gram statistics are derived from the query, with lower order gram statistics used
96 | # until the sequence contains at least 5 tokens.
97 | print(index.sample_unsmoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
98 | print(index.sample_smoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
99 |
100 | # Query whether the corpus contains "hello world"
101 | print(index.contains(tokenizer.encode("hello world")))
102 |
103 | # Get all n-grams beginning with "hello world" in the corpus
104 | print(index.positions(tokenizer.encode("hello world")))
105 | ```
106 |
107 | ## Scaling
108 |
109 | Corpora small enough to fit in memory can use an InMemoryIndex:
110 |
111 | ```python
112 | from tokengrams import InMemoryIndex
113 |
114 | tokens = [0, 1, 2, 3, 4]
115 | index = InMemoryIndex(tokens, vocab=5)
116 | ```
117 |
118 | Larger corpora must use a MemmapIndex.
119 |
120 | Some systems struggle with memory mapping extremely large tables (e.g. 40 billion tokens), causing unexpected bus errors. To prevent this split the corpus into shards then use a ShardedMemmapIndex to sort and query the table shard by shard:
121 |
122 | ```python
123 | from tokengrams import ShardedMemmapIndex
124 | from huggingface_hub import HfApi, hf_hub_download
125 |
126 | files = [
127 | file for file in HfApi().list_repo_files("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset")
128 | if file.endswith('.bin')
129 | ]
130 |
131 | index_paths = []
132 | for file in files:
133 | hf_hub_download("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset", filename=file, local_dir=".")
134 | index_paths.append((file, f'{file.removesuffix(".bin")}.idx'))
135 |
136 | index = ShardedMemmapIndex.build(index_paths, vocab=2**16, verbose=True)
137 | ```
138 | ### Tokens
139 |
140 | Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of 2^32. In practice, however, vocabulary size is limited by the length of the largest word size vector the machine can allocate in memory.
141 |
142 | Corpora with vocabulary sizes smaller than 2^16 must use u16 tokens.
143 |
144 | ## Performance
145 |
146 | Index build times for in-memory corpora scale inversely with the number of available CPU threads, whereas if the index reads from or writes to a file it is likely to be IO bound.
147 |
148 | The time complexities of count_next(query) and sample_unsmoothed(query) are O(n log n), where n is ~ the number of completions for the query. The time complexity of sample_smoothed(query) is O(m n log n) where m is the n-gram order.
149 |
150 |
151 |
152 |  |
153 |  |
154 |
155 |
156 |
157 | # Development
158 |
159 | ```bash
160 | cargo build
161 | cargo test
162 | ```
163 |
164 | Develop Python bindings:
165 |
166 | ```bash
167 | pip install maturin
168 | maturin develop
169 | pytest
170 | ```
171 |
172 | # Support
173 |
174 | The best way to get support is to open an issue on this repo or post in #interp-across-time in the [EleutherAI Discord server](https://discord.gg/eleutherai). If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
175 |
--------------------------------------------------------------------------------
/demos/generate_text.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%pip install -q maturin datasets transformers numpy pandas tokengrams"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from tokengrams import MemmapIndex, tokenize_hf_dataset\n",
19 | "from datasets import load_dataset\n",
20 | "from transformers import AutoTokenizer"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-160m\")\n",
30 | "tokenize_hf_dataset(\n",
31 | " dataset=load_dataset(\"EleutherAI/lambada_openai\", \"en\"),\n",
32 | " tokenizer=tokenizer,\n",
33 | " output_path=\"lambada.bin\",\n",
34 | " text_key=\"text\",\n",
35 | " append_eod=False,\n",
36 | " workers=1,\n",
37 | ")\n",
38 | "\n",
39 | "index = MemmapIndex.build('lambada.bin', 'lambada.idx', vocab=2**16, verbose=True)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 4,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "Once upon a time, Salem, an older wolf I'd never known as a human, had been the omega of the Boundary Wood pack. But I had seen enough of Shelby when I was clawing my way through the meningitis to know that she had fallen low in Paul's eyes and thus low in the packI'm wearing rented skis, rented ski boots that feel weird and tight and make me walk funny, plus every other kind of snow gear my mom was able to convince me to put on. I drew the line at goggles, and I stuck the unflattering wool hat into my jacket pocket, but from the neck down every inch of me is covered and padded. I don't know if I can move, let alone skiAway from the water that had changed everything for me, that had changed the lives of all of Jace's close friends. None of us would ever be the same again. But I knew that I couldn't protect myself from that kind\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "sample = index.sample_unsmoothed(tokenizer.encode(\"Once\"), n=8, k=200, num_samples=1)[0]\n",
57 | "print(tokenizer.decode(sample))"
58 | ]
59 | }
60 | ],
61 | "metadata": {
62 | "kernelspec": {
63 | "display_name": "base",
64 | "language": "python",
65 | "name": "python3"
66 | },
67 | "language_info": {
68 | "codemirror_mode": {
69 | "name": "ipython",
70 | "version": 3
71 | },
72 | "file_extension": ".py",
73 | "mimetype": "text/x-python",
74 | "name": "python",
75 | "nbconvert_exporter": "python",
76 | "pygments_lexer": "ipython3",
77 | "version": "3.10.14"
78 | }
79 | },
80 | "nbformat": 4,
81 | "nbformat_minor": 2
82 | }
83 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.10
7 | - numpy
8 | - pytest
9 | - hypothesis
10 | - maturin
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "tokengrams"
3 | version = "0.3.3"
4 | description = "Efficiently computing & storing token n-grams from large corpora "
5 | authors = [
6 | { name = "Nora Belrose", email = "nora@eleuther.ai" },
7 | { name = "Lucia Quirke", email = "lucia@eleuther.ai" }
8 | ]
9 | dependencies = [
10 | "numpy>=1.24.4",
11 | "datasets>=1.14.0",
12 | "transformers>=4.11.3",
13 | "tqdm>=4.0.0",
14 | ]
15 | readme = "README.md"
16 | requires-python = ">= 3.10"
17 |
18 | [build-system]
19 | requires = ["maturin>=1.2,<2.0"]
20 | build-backend = "maturin"
21 |
22 | [tool.maturin]
23 | module-name = "tokengrams.tokengrams"
24 | features = ["pyo3/extension-module"]
25 |
--------------------------------------------------------------------------------
/src/bindings/in_memory_index.rs:
--------------------------------------------------------------------------------
1 | use crate::in_memory_index::InMemoryIndexRs;
2 | use anyhow::Result;
3 | use pyo3::prelude::*;
4 |
5 | /// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
6 | /// Non-generic PyO3 wrapper over InMemoryIndexRs.
7 | #[pyclass]
8 | pub struct InMemoryIndex {
9 | index: Box,
10 | }
11 |
12 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data
13 | /// to other unsigned integer types.
14 | pub trait InMemoryIndexTrait {
15 | fn save_text(&self, path: String) -> Result<()>;
16 | fn save_table(&self, path: String) -> Result<()>;
17 | fn is_sorted(&self) -> bool;
18 | fn contains(&self, query: Vec) -> bool;
19 | fn positions(&self, query: Vec) -> Vec;
20 | fn count(&self, query: Vec) -> usize;
21 | fn count_next(&self, query: Vec) -> Vec;
22 | fn batch_count_next(&self, queries: Vec>) -> Vec>;
23 | fn sample_unsmoothed(
24 | &self,
25 | query: Vec,
26 | n: usize,
27 | k: usize,
28 | num_samples: usize,
29 | ) -> Result>>;
30 | fn sample_smoothed(
31 | &mut self,
32 | query: Vec,
33 | n: usize,
34 | k: usize,
35 | num_samples: usize,
36 | ) -> Result>>;
37 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec;
38 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>;
39 | fn estimate_deltas(&mut self, n: usize);
40 | }
41 |
42 | #[pymethods]
43 | impl InMemoryIndex {
44 | #[new]
45 | #[pyo3(signature = (tokens, vocab=u16::MAX as usize + 1, verbose=false))]
46 | pub fn new_py(_py: Python, tokens: Vec, vocab: usize, verbose: bool) -> Self {
47 | let index: Box = if vocab <= u16::MAX as usize + 1 {
48 | let tokens: Vec = tokens.iter().map(|&x| x as u16).collect();
49 | Box::new(InMemoryIndexRs::::new(tokens, Some(vocab), verbose))
50 | } else {
51 | let tokens: Vec = tokens.iter().map(|&x| x as u32).collect();
52 | Box::new(InMemoryIndexRs::::new(tokens, Some(vocab), verbose))
53 | };
54 |
55 | InMemoryIndex { index }
56 | }
57 |
58 | #[staticmethod]
59 | #[pyo3(signature = (path, token_limit=None, vocab=u16::MAX as usize + 1, verbose=false))]
60 | pub fn from_token_file(
61 | path: String,
62 | token_limit: Option,
63 | vocab: usize,
64 | verbose: bool,
65 | ) -> Result {
66 | if vocab <= u16::MAX as usize + 1 {
67 | Ok(InMemoryIndex {
68 | index: Box::new(InMemoryIndexRs::::from_token_file(
69 | path,
70 | token_limit,
71 | vocab,
72 | verbose,
73 | )?),
74 | })
75 | } else {
76 | Ok(InMemoryIndex {
77 | index: Box::new(InMemoryIndexRs::::from_token_file(
78 | path,
79 | token_limit,
80 | vocab,
81 | verbose,
82 | )?),
83 | })
84 | }
85 | }
86 |
87 | #[staticmethod]
88 | #[pyo3(signature = (token_path, index_path, vocab=u16::MAX as usize + 1))]
89 | pub fn from_disk(token_path: String, index_path: String, vocab: usize) -> Result {
90 | if vocab <= u16::MAX as usize + 1 {
91 | Ok(InMemoryIndex {
92 | index: Box::new(InMemoryIndexRs::::from_disk(
93 | token_path, index_path, vocab,
94 | )?),
95 | })
96 | } else {
97 | Ok(InMemoryIndex {
98 | index: Box::new(InMemoryIndexRs::::from_disk(
99 | token_path, index_path, vocab,
100 | )?),
101 | })
102 | }
103 | }
104 |
105 | pub fn save_tokens(&self, path: String) -> Result<()> {
106 | self.index.save_text(path)
107 | }
108 |
109 | pub fn save_index(&self, path: String) -> Result<()> {
110 | self.index.save_table(path)
111 | }
112 |
113 | pub fn is_sorted(&self) -> bool {
114 | self.index.is_sorted()
115 | }
116 |
117 | pub fn contains(&self, query: Vec) -> bool {
118 | self.index.contains(query)
119 | }
120 |
121 | pub fn positions(&self, query: Vec) -> Vec {
122 | self.index.positions(query).to_vec()
123 | }
124 |
125 | pub fn count(&self, query: Vec) -> usize {
126 | self.index.count(query)
127 | }
128 |
129 | pub fn count_next(&self, query: Vec) -> Vec {
130 | self.index.count_next(query)
131 | }
132 |
133 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> {
134 | self.index.batch_count_next(queries)
135 | }
136 |
137 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.
138 | pub fn sample_unsmoothed(
139 | &self,
140 | query: Vec,
141 | n: usize,
142 | k: usize,
143 | num_samples: usize,
144 | ) -> Result>> {
145 | self.index.sample_unsmoothed(query, n, k, num_samples)
146 | }
147 |
148 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
149 | /// tokens in the query.
150 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec {
151 | self.index.get_smoothed_probs(query)
152 | }
153 |
154 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
155 | /// tokens in the query.
156 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> {
157 | self.index.batch_get_smoothed_probs(queries)
158 | }
159 |
160 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model.
161 | pub fn sample_smoothed(
162 | &mut self,
163 | query: Vec,
164 | n: usize,
165 | k: usize,
166 | num_samples: usize,
167 | ) -> Result>> {
168 | self.index.sample_smoothed(query, n, k, num_samples)
169 | }
170 |
171 | /// Warning: O(k**n) where k is vocabulary size, use with caution.
172 | /// Improve smoothed model quality by replacing the default delta hyperparameters
173 | /// for models of order n and below with improved estimates over the entire index.
174 | /// See Chen & Goodman, "An Empirical Study of Smoothing Techniques for Language Modeling", page 16 (original link lost in formatting — TODO restore URL).
175 | pub fn estimate_deltas(&mut self, n: usize) {
176 | self.index.estimate_deltas(n);
177 | }
178 | }
179 |
--------------------------------------------------------------------------------
/src/bindings/memmap_index.rs:
--------------------------------------------------------------------------------
1 | use crate::memmap_index::MemmapIndexRs;
2 | use anyhow::Result;
3 | use pyo3::prelude::*;
4 |
5 | /// A memmap index exposes suffix table functionality over text corpora too large to fit in memory.
6 | #[pyclass]
7 | pub struct MemmapIndex {
8 | index: Box,
9 | }
10 |
11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data
12 | /// to other unsigned integer types.
13 | pub trait MemmapIndexTrait {
14 | fn is_sorted(&self) -> bool;
15 | fn contains(&self, query: Vec) -> bool;
16 | fn positions(&self, query: Vec) -> Vec;
17 | fn count(&self, query: Vec) -> usize;
18 | fn count_next(&self, query: Vec) -> Vec;
19 | fn batch_count_next(&self, queries: Vec>) -> Vec>;
20 | fn sample_unsmoothed(
21 | &self,
22 | query: Vec,
23 | n: usize,
24 | k: usize,
25 | num_samples: usize,
26 | ) -> Result>>;
27 | fn sample_smoothed(
28 | &mut self,
29 | query: Vec,
30 | n: usize,
31 | k: usize,
32 | num_samples: usize,
33 | ) -> Result>>;
34 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec;
35 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>;
36 | fn estimate_deltas(&mut self, n: usize);
37 | }
38 |
39 | #[pymethods]
40 | impl MemmapIndex {
41 | #[new]
42 | #[pyo3(signature = (text_path, table_path, vocab=u16::MAX as usize + 1))]
43 | pub fn new(
44 | _py: Python,
45 | text_path: String,
46 | table_path: String,
47 | vocab: usize,
48 | ) -> PyResult {
49 | if vocab <= u16::MAX as usize + 1 {
50 | Ok(MemmapIndex {
51 | index: Box::new(MemmapIndexRs::::new(text_path, table_path, vocab)?),
52 | })
53 | } else {
54 | Ok(MemmapIndex {
55 | index: Box::new(MemmapIndexRs::::new(text_path, table_path, vocab)?),
56 | })
57 | }
58 | }
59 |
60 | #[staticmethod]
61 | #[pyo3(signature = (text_path, table_path, vocab=u16::MAX as usize + 1, verbose=false))]
62 | pub fn build(
63 | text_path: String,
64 | table_path: String,
65 | vocab: usize,
66 | verbose: bool,
67 | ) -> PyResult {
68 | if vocab <= u16::MAX as usize + 1 {
69 | Ok(MemmapIndex {
70 | index: Box::new(MemmapIndexRs::::build(
71 | text_path, table_path, vocab, verbose,
72 | )?),
73 | })
74 | } else {
75 | Ok(MemmapIndex {
76 | index: Box::new(MemmapIndexRs::::build(
77 | text_path, table_path, vocab, verbose,
78 | )?),
79 | })
80 | }
81 | }
82 |
83 | pub fn is_sorted(&self) -> bool {
84 | self.index.is_sorted()
85 | }
86 |
87 | pub fn contains(&self, query: Vec) -> bool {
88 | self.index.contains(query)
89 | }
90 |
91 | pub fn positions(&self, query: Vec) -> Vec {
92 | self.index.positions(query).to_vec()
93 | }
94 |
95 | pub fn count(&self, query: Vec) -> usize {
96 | self.index.positions(query).len()
97 | }
98 |
99 | pub fn count_next(&self, query: Vec) -> Vec {
100 | self.index.count_next(query)
101 | }
102 |
103 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> {
104 | self.index.batch_count_next(queries)
105 | }
106 |
107 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.
108 | pub fn sample_unsmoothed(
109 | &self,
110 | query: Vec,
111 | n: usize,
112 | k: usize,
113 | num_samples: usize,
114 | ) -> Result>> {
115 | self.index.sample_unsmoothed(query, n, k, num_samples)
116 | }
117 |
118 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
119 | /// tokens in the query.
120 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec {
121 | self.index.get_smoothed_probs(query)
122 | }
123 |
124 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
125 | /// tokens in the query.
126 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> {
127 | self.index.batch_get_smoothed_probs(queries)
128 | }
129 |
130 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model.
131 | pub fn sample_smoothed(
132 | &mut self,
133 | query: Vec,
134 | n: usize,
135 | k: usize,
136 | num_samples: usize,
137 | ) -> Result>> {
138 | self.index.sample_smoothed(query, n, k, num_samples)
139 | }
140 |
141 | /// Warning: O(k**n) where k is vocabulary size, use with caution.
142 | /// Improve smoothed model quality by replacing the default delta hyperparameters
143 | /// for models of order n and below with improved estimates over the entire index.
144 | /// See Chen & Goodman, "An Empirical Study of Smoothing Techniques for Language Modeling", page 16 (original link lost in formatting — TODO restore URL).
145 | pub fn estimate_deltas(&mut self, n: usize) {
146 | self.index.estimate_deltas(n);
147 | }
148 | }
149 |
--------------------------------------------------------------------------------
/src/bindings/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod in_memory_index;
2 | pub mod memmap_index;
3 | pub mod sharded_memmap_index;
4 | pub mod sharded_in_memory_index;
--------------------------------------------------------------------------------
/src/bindings/sharded_in_memory_index.rs:
--------------------------------------------------------------------------------
1 | use crate::sharded_in_memory_index::ShardedInMemoryIndexRs;
2 | use anyhow::Result;
3 | use pyo3::prelude::*;
4 |
5 | /// Expose suffix table functionality over text corpora too large to fit in memory.
6 | #[pyclass]
7 | pub struct ShardedInMemoryIndex {
8 | index: Box,
9 | }
10 |
11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data
12 | /// to other unsigned integer types.
13 | pub trait ShardedInMemoryIndexTrait {
14 | fn is_sorted(&self) -> bool;
15 | fn contains(&self, query: Vec) -> bool;
16 | fn count(&self, query: Vec) -> usize;
17 | fn count_next(&self, query: Vec) -> Vec;
18 | fn batch_count_next(&self, queries: Vec>) -> Vec>;
19 | fn sample_unsmoothed(
20 | &self,
21 | query: Vec,
22 | n: usize,
23 | k: usize,
24 | num_samples: usize,
25 | ) -> Result>>;
26 | fn sample_smoothed(
27 | &mut self,
28 | query: Vec,
29 | n: usize,
30 | k: usize,
31 | num_samples: usize,
32 | ) -> Result>>;
33 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec;
34 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>;
35 | fn estimate_deltas(&mut self, n: usize);
36 | }
37 |
38 | #[pymethods]
39 | impl ShardedInMemoryIndex {
40 | #[new]
41 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1))]
42 | pub fn new(_py: Python, paths: Vec<(String, String)>, vocab: usize) -> PyResult {
43 | if vocab <= u16::MAX as usize + 1 {
44 | Ok(ShardedInMemoryIndex {
45 | index: Box::new(ShardedInMemoryIndexRs::::new(paths, vocab)?),
46 | })
47 | } else {
48 | Ok(ShardedInMemoryIndex {
49 | index: Box::new(ShardedInMemoryIndexRs::::new(paths, vocab)?),
50 | })
51 | }
52 | }
53 |
54 | pub fn is_sorted(&self) -> bool {
55 | self.index.is_sorted()
56 | }
57 |
58 | pub fn contains(&self, query: Vec) -> bool {
59 | self.index.contains(query)
60 | }
61 |
62 | pub fn count(&self, query: Vec) -> usize {
63 | self.index.count(query)
64 | }
65 |
66 | pub fn count_next(&self, query: Vec) -> Vec {
67 | self.index.count_next(query)
68 | }
69 |
70 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> {
71 | self.index.batch_count_next(queries)
72 | }
73 |
74 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.
75 | pub fn sample_unsmoothed(
76 | &self,
77 | query: Vec,
78 | n: usize,
79 | k: usize,
80 | num_samples: usize,
81 | ) -> Result>> {
82 | self.index.sample_unsmoothed(query, n, k, num_samples)
83 | }
84 |
85 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
86 | /// tokens in the query.
87 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec {
88 | self.index.get_smoothed_probs(query)
89 | }
90 |
91 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
92 | /// tokens in the query.
93 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> {
94 | self.index.batch_get_smoothed_probs(queries)
95 | }
96 |
97 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model.
98 | pub fn sample_smoothed(
99 | &mut self,
100 | query: Vec,
101 | n: usize,
102 | k: usize,
103 | num_samples: usize,
104 | ) -> Result>> {
105 | self.index.sample_smoothed(query, n, k, num_samples)
106 | }
107 |
108 | /// Warning: O(k**n) where k is vocabulary size, use with caution.
109 | /// Improve smoothed model quality by replacing the default delta hyperparameters
110 | /// for models of order n and below with improved estimates over the entire index.
111 | /// See Chen & Goodman, "An Empirical Study of Smoothing Techniques for Language Modeling", page 16 (original link lost in formatting — TODO restore URL).
112 | pub fn estimate_deltas(&mut self, n: usize) {
113 | self.index.estimate_deltas(n);
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/src/bindings/sharded_memmap_index.rs:
--------------------------------------------------------------------------------
1 | use crate::sharded_memmap_index::ShardedMemmapIndexRs;
2 | use anyhow::Result;
3 | use pyo3::prelude::*;
4 |
5 | /// Expose suffix table functionality over text corpora too large to fit in memory.
6 | #[pyclass]
7 | pub struct ShardedMemmapIndex {
8 | index: Box,
9 | }
10 |
11 | /// This trait is non-generic for PyO3 compatibility. Implementing structs may cast data
12 | /// to other unsigned integer types.
13 | pub trait ShardedMemmapIndexTrait {
14 | fn is_sorted(&self) -> bool;
15 | fn contains(&self, query: Vec) -> bool;
16 | fn count(&self, query: Vec) -> usize;
17 | fn count_next(&self, query: Vec) -> Vec;
18 | fn batch_count_next(&self, queries: Vec>) -> Vec>;
19 | fn sample_unsmoothed(
20 | &self,
21 | query: Vec,
22 | n: usize,
23 | k: usize,
24 | num_samples: usize,
25 | ) -> Result>>;
26 | fn sample_smoothed(
27 | &mut self,
28 | query: Vec,
29 | n: usize,
30 | k: usize,
31 | num_samples: usize,
32 | ) -> Result>>;
33 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec;
34 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec>;
35 | fn estimate_deltas(&mut self, n: usize);
36 | }
37 |
38 | #[pymethods]
39 | impl ShardedMemmapIndex {
40 | #[new]
41 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1))]
42 | pub fn new(_py: Python, paths: Vec<(String, String)>, vocab: usize) -> PyResult {
43 | if vocab <= u16::MAX as usize + 1 {
44 | Ok(ShardedMemmapIndex {
45 | index: Box::new(ShardedMemmapIndexRs::::new(paths, vocab)?),
46 | })
47 | } else {
48 | Ok(ShardedMemmapIndex {
49 | index: Box::new(ShardedMemmapIndexRs::::new(paths, vocab)?),
50 | })
51 | }
52 | }
53 |
54 | #[staticmethod]
55 | #[pyo3(signature = (paths, vocab=u16::MAX as usize + 1, verbose=false))]
56 | pub fn build(paths: Vec<(String, String)>, vocab: usize, verbose: bool) -> PyResult {
57 | if vocab <= u16::MAX as usize + 1 {
58 | Ok(ShardedMemmapIndex {
59 | index: Box::new(ShardedMemmapIndexRs::::build(paths, vocab, verbose)?),
60 | })
61 | } else {
62 | Ok(ShardedMemmapIndex {
63 | index: Box::new(ShardedMemmapIndexRs::::build(paths, vocab, verbose)?),
64 | })
65 | }
66 | }
67 |
68 | pub fn is_sorted(&self) -> bool {
69 | self.index.is_sorted()
70 | }
71 |
72 | pub fn contains(&self, query: Vec) -> bool {
73 | self.index.contains(query)
74 | }
75 |
76 | pub fn count(&self, query: Vec) -> usize {
77 | self.index.count(query)
78 | }
79 |
80 | pub fn count_next(&self, query: Vec) -> Vec {
81 | self.index.count_next(query)
82 | }
83 |
84 | pub fn batch_count_next(&self, queries: Vec>) -> Vec> {
85 | self.index.batch_count_next(queries)
86 | }
87 |
88 | /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.
89 | pub fn sample_unsmoothed(
90 | &self,
91 | query: Vec,
92 | n: usize,
93 | k: usize,
94 | num_samples: usize,
95 | ) -> Result>> {
96 | self.index.sample_unsmoothed(query, n, k, num_samples)
97 | }
98 |
99 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
100 | /// tokens in the query.
101 | pub fn get_smoothed_probs(&mut self, query: Vec) -> Vec {
102 | self.index.get_smoothed_probs(query)
103 | }
104 |
105 | /// Returns interpolated Kneser-Ney smoothed token probability distribution using all previous
106 | /// tokens in the query.
107 | pub fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> {
108 | self.index.batch_get_smoothed_probs(queries)
109 | }
110 |
111 | /// Autoregressively sample num_samples of k characters from a Kneser-Ney smoothed n-gram model.
112 | pub fn sample_smoothed(
113 | &mut self,
114 | query: Vec,
115 | n: usize,
116 | k: usize,
117 | num_samples: usize,
118 | ) -> Result>> {
119 | self.index.sample_smoothed(query, n, k, num_samples)
120 | }
121 |
122 | /// Warning: O(k**n) where k is vocabulary size, use with caution.
123 | /// Improve smoothed model quality by replacing the default delta hyperparameters
124 | /// for models of order n and below with improved estimates over the entire index.
125 | /// See Chen & Goodman, "An Empirical Study of Smoothing Techniques for Language Modeling", page 16 (original link lost in formatting — TODO restore URL).
126 | pub fn estimate_deltas(&mut self, n: usize) {
127 | self.index.estimate_deltas(n);
128 | }
129 | }
130 |
--------------------------------------------------------------------------------
/src/in_memory_index.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use funty::Unsigned;
3 | use rayon::prelude::*;
4 | use std::collections::HashMap;
5 | use std::fmt::Debug;
6 | use std::fs::{File, OpenOptions};
7 | use std::io::Read;
8 |
9 | use crate::bindings::in_memory_index::InMemoryIndexTrait;
10 | use crate::mmap_slice::MmapSliceMut;
11 | use crate::sample::{KneserNeyCache, Sample};
12 | use crate::table::SuffixTable;
13 | use crate::util::transmute_slice;
14 |
15 | /// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
16 | pub struct InMemoryIndexRs {
17 | table: SuffixTable, Box<[u64]>>,
18 | cache: KneserNeyCache,
19 | }
20 |
21 | impl InMemoryIndexRs {
22 | pub fn new(tokens: Vec, vocab: Option, verbose: bool) -> Self {
23 | let vocab = vocab.unwrap_or(u16::MAX as usize + 1);
24 |
25 | let table = SuffixTable::new(tokens, Some(vocab), verbose);
26 | debug_assert!(table.is_sorted());
27 |
28 | InMemoryIndexRs {
29 | table,
30 | cache: KneserNeyCache::default(),
31 | }
32 | }
33 |
34 | pub fn from_token_file(
35 | path: String,
36 | token_limit: Option,
37 | vocab: usize,
38 | verbose: bool,
39 | ) -> Result {
40 | let mut buffer = Vec::new();
41 | let mut file = File::open(&path)?;
42 |
43 | if let Some(max_tokens) = token_limit {
44 | // Limit on the number of tokens to consider is provided
45 | let max_bytes = max_tokens * std::mem::size_of::();
46 | file.take(max_bytes as u64).read_to_end(&mut buffer)?;
47 | } else {
48 | file.read_to_end(&mut buffer)?;
49 | };
50 |
51 | let tokens = transmute_slice::(buffer.as_slice());
52 | let table = SuffixTable::new(tokens, Some(vocab), verbose);
53 | debug_assert!(table.is_sorted());
54 |
55 | Ok(InMemoryIndexRs {
56 | table,
57 | cache: KneserNeyCache::default(),
58 | })
59 | }
60 |
61 | fn read_file_to_boxed_slice(path: &str) -> Result> {
62 | let mut file = File::open(path)?;
63 | let file_len_bytes = file.metadata()?.len() as usize;
64 |
65 | // Ensure file size is a multiple of size of E
66 | if file_len_bytes % std::mem::size_of::() != 0 {
67 | anyhow::bail!("File size is not a multiple of element size");
68 | }
69 |
70 | let num_elements = file_len_bytes / std::mem::size_of::();
71 | let mut vec: Vec = Vec::with_capacity(num_elements);
72 | unsafe {
73 | let buf = std::slice::from_raw_parts_mut(vec.as_mut_ptr() as *mut u8, file_len_bytes);
74 | file.read_exact(buf)?;
75 | vec.set_len(num_elements);
76 | }
77 |
78 | Ok(vec.into_boxed_slice())
79 | }
80 |
81 | pub fn from_disk(text_path: String, table_path: String, vocab: usize) -> Result {
82 | let text = Self::read_file_to_boxed_slice::(&text_path)?;
83 | let table = Self::read_file_to_boxed_slice::(&table_path)?;
84 |
85 | let suffix_table = SuffixTable::from_parts(text, table, Some(vocab));
86 | debug_assert!(suffix_table.is_sorted());
87 |
88 | Ok(InMemoryIndexRs {
89 | table: suffix_table,
90 | cache: KneserNeyCache::default(),
91 | })
92 | }
93 |
94 | pub fn save_text(&self, path: String) -> Result<()> {
95 | let text = self.table.get_text();
96 | let file = OpenOptions::new()
97 | .create(true)
98 | .read(true)
99 | .write(true)
100 | .open(&path)?;
101 |
102 | let file_len = text.len() * std::mem::size_of::();
103 | file.set_len(file_len as u64)?;
104 |
105 | let mut mmap = MmapSliceMut::::new(&file)?;
106 | mmap.copy_from_slice(text);
107 | mmap.flush()?;
108 |
109 | Ok(())
110 | }
111 |
112 | pub fn save_table(&self, path: String) -> Result<()> {
113 | let table = self.table.get_table();
114 | let file = OpenOptions::new()
115 | .create(true)
116 | .read(true)
117 | .write(true)
118 | .open(&path)?;
119 |
120 | file.set_len((table.len() * 8) as u64)?;
121 |
122 | let mut mmap = MmapSliceMut::::new(&file)?;
123 | mmap.copy_from_slice(table);
124 | mmap.flush()?;
125 |
126 | Ok(())
127 | }
128 | }
129 |
130 | impl Sample for InMemoryIndexRs {
131 | fn get_cache(&self) -> &KneserNeyCache {
132 | &self.cache
133 | }
134 |
135 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache {
136 | &mut self.cache
137 | }
138 |
139 | fn count_next_slice(&self, query: &[T]) -> Vec {
140 | self.table.count_next(query)
141 | }
142 |
143 | fn count_ngrams(&self, n: usize) -> HashMap {
144 | self.table.count_ngrams(n)
145 | }
146 | }
147 |
148 | impl InMemoryIndexTrait for InMemoryIndexRs {
149 | fn save_table(&self, table_path: String) -> Result<()> {
150 | self.save_table(table_path)
151 | }
152 |
153 | fn save_text(&self, text_path: String) -> Result<()> {
154 | self.save_text(text_path)
155 | }
156 |
157 | fn is_sorted(&self) -> bool {
158 | self.table.is_sorted()
159 | }
160 |
161 | fn contains(&self, query: Vec) -> bool {
162 | let query: Vec = query
163 | .iter()
164 | .filter_map(|&item| T::try_from(item).ok())
165 | .collect();
166 | self.table.contains(&query)
167 | }
168 |
169 | fn positions(&self, query: Vec) -> Vec {
170 | let query: Vec = query
171 | .iter()
172 | .filter_map(|&item| T::try_from(item).ok())
173 | .collect();
174 | self.table.positions(&query).to_vec()
175 | }
176 |
177 | fn count(&self, query: Vec) -> usize {
178 | let query: Vec = query
179 | .iter()
180 | .filter_map(|&item| T::try_from(item).ok())
181 | .collect();
182 | self.table.positions(&query).len()
183 | }
184 |
185 | fn count_next(&self, query: Vec) -> Vec {
186 | let query: Vec = query
187 | .iter()
188 | .filter_map(|&item| T::try_from(item).ok())
189 | .collect();
190 | self.table.count_next(&query)
191 | }
192 |
193 | fn batch_count_next(&self, queries: Vec>) -> Vec> {
194 | queries
195 | .into_par_iter()
196 | .map(|query| self.count_next(query))
197 | .collect()
198 | }
199 |
200 | fn sample_smoothed(
201 | &mut self,
202 | query: Vec,
203 | n: usize,
204 | k: usize,
205 | num_samples: usize,
206 | ) -> Result>> {
207 | let query: Vec = query
208 | .iter()
209 | .filter_map(|&item| T::try_from(item).ok())
210 | .collect();
211 |
212 | let samples_batch = >::sample_smoothed(self, &query, n, k, num_samples)?;
213 | Ok(samples_batch
214 | .into_iter()
215 | .map(|samples| {
216 | samples
217 | .into_iter()
218 | .filter_map(|sample| {
219 | match TryInto::::try_into(sample) {
220 | Ok(value) => Some(value),
221 | Err(_) => None, // Silently skip values that can't be converted
222 | }
223 | })
224 | .collect::>()
225 | })
226 | .collect())
227 | }
228 |
229 | fn sample_unsmoothed(
230 | &self,
231 | query: Vec,
232 | n: usize,
233 | k: usize,
234 | num_samples: usize,
235 | ) -> Result>> {
236 | let query: Vec = query
237 | .iter()
238 | .filter_map(|&item| T::try_from(item).ok())
239 | .collect();
240 |
241 | let samples_batch =
242 | >::sample_unsmoothed(self, &query, n, k, num_samples)?;
243 | Ok(samples_batch
244 | .into_iter()
245 | .map(|samples| {
246 | samples
247 | .into_iter()
248 | .filter_map(|sample| {
249 | match TryInto::::try_into(sample) {
250 | Ok(value) => Some(value),
251 | Err(_) => None, // Silently skip values that can't be converted
252 | }
253 | })
254 | .collect::>()
255 | })
256 | .collect())
257 | }
258 |
259 | fn get_smoothed_probs(&mut self, query: Vec) -> Vec {
260 | let query: Vec = query
261 | .iter()
262 | .filter_map(|&item| T::try_from(item).ok())
263 | .collect();
264 |
265 | >::get_smoothed_probs(self, &query)
266 | }
267 |
268 | fn batch_get_smoothed_probs(&mut self, queries: Vec>) -> Vec> {
269 | let queries: Vec> = queries
270 | .into_iter()
271 | .map(|query| {
272 | query
273 | .iter()
274 | .filter_map(|&item| T::try_from(item).ok())
275 | .collect()
276 | })
277 | .collect();
278 | >::batch_get_smoothed_probs(self, &queries)
279 | }
280 |
281 | fn estimate_deltas(&mut self, n: usize) {
282 | >::estimate_deltas(self, n)
283 | }
284 | }
285 |
286 |
#[cfg(test)]
pub mod tests {
    use super::*;
    use utf16_literal::utf16;
    use crate::table::SuffixTable;

    /// Builds an in-memory suffix table over the UTF-16 encoding of `text`.
    fn sais(text: &str) -> SuffixTable {
        SuffixTable::new(text.encode_utf16().collect::<Vec<_>>(), None, false)
    }

    /// UTF-16 encodes `s` and widens each code unit to `usize`.
    fn utf16_as_usize(s: &str) -> Vec<usize> {
        s.encode_utf16().map(|x| x as usize).collect()
    }

    #[test]
    fn sample_unsmoothed_empty_query_exists() {
        let s = utf16!("aaa");
        let index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));

        let seqs = index.sample_unsmoothed(&[], 3, 10, 1).unwrap();

        assert_eq!(*seqs[0].last().unwrap(), s[0]);
    }

    #[test]
    fn sample_unsmoothed_u16_exists() {
        let s = utf16!("aaaa");
        let a = &s[0..1];
        let index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));

        let seqs = index.sample_unsmoothed(a, 3, 10, 1).unwrap();

        assert_eq!(*seqs[0].last().unwrap(), a[0]);
    }

    #[test]
    fn sample_unsmoothed_u32_exists() {
        let s: Vec<u32> = "aaaa".encode_utf16().map(|c| c as u32).collect();
        // Vocab larger than u16::MAX forces the u32 token path.
        let u32_vocab = Some(u16::MAX as usize + 2);
        let index: Box<dyn Sample<u32>> =
            Box::new(InMemoryIndexRs::<u32>::new(s.clone(), u32_vocab, false));

        let seqs = index.sample_unsmoothed(&s[0..1], 3, 10, 1).unwrap();

        assert_eq!(*seqs[0].last().unwrap(), s[0]);
    }

    #[test]
    fn sample_unsmoothed_usize_exists() {
        let s = utf16_as_usize("aaaa");
        let index: Box<dyn InMemoryIndexTrait> =
            Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));

        let seqs = index.sample_unsmoothed(s[0..1].to_vec(), 3, 10, 1).unwrap();

        assert_eq!(*seqs[0].last().unwrap(), s[0]);
    }

    #[test]
    fn sample_smoothed_exists() {
        let s = utf16!("aabbccabccba");
        let mut index: Box<dyn Sample<u16>> =
            Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));

        let tokens = &index.sample_smoothed(&s[0..1], 3, 10, 1).unwrap()[0];

        assert_eq!(tokens.len(), 11);
    }

    #[test]
    fn sample_smoothed_empty_query_exists() {
        let s: Vec<u16> = "aabbccabccba".encode_utf16().collect();
        let mut index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s, None, false));

        let tokens = &index.sample_smoothed(&[], 1, 10, 10).unwrap()[0];

        assert_eq!(tokens.len(), 10);
    }

    #[test]
    fn smoothed_probs_exists() {
        let tokens = "aaaaaaaabc".to_string();
        let tokens_vec: Vec<u16> = tokens.encode_utf16().collect();
        let query: Vec<_> = vec![utf16!("b")[0]];

        // Get unsmoothed probs for query
        let sa: SuffixTable = sais(&tokens);
        let bigram_counts = sa.count_next(&query);
        let unsmoothed_probs = bigram_counts
            .iter()
            .map(|&x| x as f64 / bigram_counts.iter().sum::<usize>() as f64)
            .collect::<Vec<f64>>();

        // Get smoothed probs for query
        let mut index: Box<dyn Sample<u16>> =
            Box::new(InMemoryIndexRs::new(tokens_vec, None, false));
        let smoothed_probs = index.get_smoothed_probs(&query);

        // Compare unsmoothed and smoothed probabilities
        let a = utf16!("a")[0] as usize;
        let c = utf16!("c")[0] as usize;

        // The naive bigram probability for query 'b' is p(c) = 1.0.
        assert!(unsmoothed_probs[a] == 0.0);
        assert!(unsmoothed_probs[c] == 1.0);

        // The smoothed bigram probabilities interpolate with the lower-order unigram
        // probabilities where p(a) is high, lowering p(c)
        assert!(smoothed_probs[a] > 0.1);
        assert!(smoothed_probs[c] < 1.0);
    }
}
395 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod mmap_slice;
2 | pub use bindings::in_memory_index::InMemoryIndex;
3 | pub use bindings::memmap_index::MemmapIndex;
4 | pub use bindings::sharded_memmap_index::ShardedMemmapIndex;
5 | pub use bindings::sharded_in_memory_index::ShardedInMemoryIndex;
6 | pub use sharded_in_memory_index::ShardedInMemoryIndexRs;
7 | pub use in_memory_index::InMemoryIndexRs;
8 | pub use sample::Sample;
9 |
10 | pub use table::SuffixTable;
11 |
12 | /// Python bindings
13 | #[cfg(not(feature = "rust"))]
14 | use pyo3::prelude::*;
15 |
16 | mod bindings;
17 | mod in_memory_index;
18 | mod memmap_index;
19 | mod par_quicksort;
20 | mod sample;
21 | mod sharded_memmap_index;
22 | mod sharded_in_memory_index;
23 | mod table;
24 | mod util;
25 |
26 | #[cfg(not(feature = "rust"))]
27 | #[pymodule]
28 | fn tokengrams(m: &Bound<'_, PyModule>) -> PyResult<()> {
29 | m.add_class::()?;
30 | m.add_class::()?;
31 | m.add_class::()?;
32 | m.add_class::()?;
33 | Ok(())
34 | }
35 |
--------------------------------------------------------------------------------
/src/memmap_index.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use funty::Unsigned;
3 | use rayon::prelude::*;
4 | use std::collections::HashMap;
5 | use std::fs::{File, OpenOptions};
6 | use std::time::Instant;
7 |
8 | use crate::bindings::memmap_index::MemmapIndexTrait;
9 | use crate::mmap_slice::{MmapSlice, MmapSliceMut};
10 | use crate::par_quicksort::par_sort_unstable_by_key;
11 | use crate::sample::{KneserNeyCache, Sample};
12 | use crate::table::SuffixTable;
13 |
14 | /// A memmap index exposes suffix table functionality over text corpora too large to fit in memory.
15 | pub struct MemmapIndexRs {
16 | table: SuffixTable, MmapSlice>,
17 | cache: KneserNeyCache,
18 | }
19 |
20 | impl MemmapIndexRs {
21 | pub fn new(text_path: String, table_path: String, vocab: usize) -> Result {
22 | let text_file = File::open(&text_path)?;
23 | let table_file = File::open(&table_path)?;
24 |
25 | let table = SuffixTable::from_parts(
26 | MmapSlice::::new(&text_file)?,
27 | MmapSlice::new(&table_file)?,
28 | Some(vocab),
29 | );
30 | assert!(table.is_sorted());
31 |
32 | Ok(MemmapIndexRs {
33 | table,
34 | cache: KneserNeyCache::default(),
35 | })
36 | }
37 |
38 | pub fn build(
39 | text_path: String,
40 | table_path: String,
41 | vocab: usize,
42 | verbose: bool,
43 | ) -> Result {
44 | // Memory map the text as read-only
45 | let text_mmap = MmapSlice::::new(&File::open(&text_path).unwrap()).unwrap();
46 |
47 | let table_file = OpenOptions::new()
48 | .create(true)
49 | .read(true)
50 | .write(true)
51 | .open(&table_path)?;
52 |
53 | // Allocate space on disk for the table
54 | let table_size = text_mmap.len() * 8;
55 | table_file.set_len(table_size as u64)?;
56 |
57 | if verbose {
58 | println!("Writing indices to disk...");
59 | }
60 | let start = Instant::now();
61 | let mut table_mmap = MmapSliceMut::::new(&table_file)?;
62 | table_mmap
63 | .iter_mut()
64 | .enumerate()
65 | .for_each(|(i, x)| *x = i as u64);
66 |
67 | assert_eq!(table_mmap.len(), text_mmap.len());
68 | if verbose {
69 | println!("Time elapsed: {:?}", start.elapsed());
70 | }
71 | let start = Instant::now();
72 |
73 | // TODO: Be even smarter about this? We may need to take into account the number of CPUs
74 | // available as well. These magic numbers were tuned on a server with 48 physical cores.
75 | // Empirically we start getting stack overflows between 5B and 10B tokens when using the
76 | // default stack size of 2MB. We scale the stack size as log2(n) * 8MB to avoid this.
77 | let scale = (text_mmap.len() as f64) / 5e9; // 5B tokens
78 | let stack_size = scale.log2().max(1.0) * 8e6; // 8MB
79 |
80 | rayon::ThreadPoolBuilder::new()
81 | .stack_size(stack_size as usize)
82 | .build()
83 | .unwrap()
84 | .install(|| {
85 | // Sort the indices by the suffixes they point to.
86 | // The unstable algorithm is critical for avoiding out-of-memory errors, since it does
87 | // not allocate any more memory than the input and output slices.
88 | println!("Sorting indices...");
89 | par_sort_unstable_by_key(
90 | table_mmap.as_slice_mut(),
91 | |&i| &text_mmap[i as usize..],
92 | verbose,
93 | );
94 | });
95 | if verbose {
96 | println!("Time elapsed: {:?}", start.elapsed());
97 | }
98 |
99 | // Re-open the table as read-only
100 | let table_mmap = MmapSlice::new(&table_file)?;
101 | let table = SuffixTable::from_parts(text_mmap, table_mmap, Some(vocab));
102 | debug_assert!(table.is_sorted());
103 |
104 | Ok(MemmapIndexRs {
105 | table,
106 | cache: KneserNeyCache::default(),
107 | })
108 | }
109 | }
110 |
111 | impl Sample for MemmapIndexRs {
112 | fn get_cache(&self) -> &KneserNeyCache {
113 | &self.cache
114 | }
115 |
116 | fn get_mut_cache(&mut self) -> &mut KneserNeyCache {
117 | &mut self.cache
118 | }
119 |
120 | fn count_next_slice(&self, query: &[T]) -> Vec {
121 | self.table.count_next(query)
122 | }
123 |
124 | fn count_ngrams(&self, n: usize) -> HashMap {
125 | self.table.count_ngrams(n)
126 | }
127 | }
128 |
129 | impl MemmapIndexTrait for MemmapIndexRs
130 | where
131 | T: Unsigned,
132 | {
133 | fn positions(&self, query: Vec) -> Vec {
134 | let query: Vec = query
135 | .iter()
136 | .filter_map(|&item| T::try_from(item).ok())
137 | .collect();
138 | self.table.positions(&query).to_vec()
139 | }
140 |
    /// Returns true if the underlying suffix table is sorted (a valid suffix array).
    fn is_sorted(&self) -> bool {
        self.table.is_sorted()
    }
144 |
145 | fn contains(&self, query: Vec) -> bool {
146 | let query: Vec