├── .gitignore ├── .pylintrc ├── .python-version ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── data ├── general_abbreviations.json ├── legal_abbreviations.json ├── test.jsonl.gz └── train2.jsonl.gz ├── docs ├── README.md ├── advanced_usage.md ├── api_reference.md ├── architecture.md ├── getting_started.md ├── overview.md └── training_models.md ├── mypy.ini ├── nupunkt ├── __init__.py ├── core │ ├── __init__.py │ ├── base.py │ ├── constants.py │ ├── language_vars.py │ ├── parameters.py │ └── tokens.py ├── models │ ├── __init__.py │ └── default_model.bin ├── nupunkt.py ├── py.typed ├── tokenizers │ ├── __init__.py │ ├── paragraph_tokenizer.py │ └── sentence_tokenizer.py ├── trainers │ ├── __init__.py │ └── base_trainer.py └── utils │ ├── __init__.py │ ├── compression.py │ ├── iteration.py │ └── statistics.py ├── paragraphs.py ├── pyproject.toml ├── scripts ├── benchmark_cache_sizes.py ├── profile_default_model.py ├── test_default_model.py ├── train_default_model.py └── utils │ ├── README.md │ ├── __init__.py │ ├── check_abbreviation.py │ ├── convert_model.py │ ├── model_info.py │ ├── optimize_model.py │ ├── profile_tokenizer.py │ └── test_tokenizer.py ├── sentences.py └── tests ├── __init__.py ├── conftest.py ├── test_language_vars.py ├── test_parameters.py ├── test_tokens.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Do not ignore model files 30 | !nupunkt/models/*.json 31 | !nupunkt/models/*.json.xz 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 
99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | .idea/ 173 | 174 | # Ruff stuff: 175 | .ruff_cache/ 176 | 177 | # PyPI configuration file 178 | .pypirc 179 | data/train.jsonl.gz 180 | 181 | # Profiling data 182 | profiles/ 183 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | # Use multiple processes to speed up Pylint. 3 | jobs=4 4 | 5 | # Minimum Python version to use for version dependent checks. 
6 | py-version = 3.8 7 | 8 | # A comma-separated list of file extensions that will be checked 9 | init-hook="import sys; sys.path.append('./')" 10 | extension-pkg-whitelist=numpy,pandas 11 | 12 | # Specify a list of additional imports that are not correctly handled 13 | additional-builtins= 14 | 15 | # Add files to be processed to the ignore list 16 | ignore=CVS 17 | 18 | # Files or directories to be skipped 19 | ignore-patterns= 20 | 21 | # Controlling the stats printed at the end 22 | reports=no 23 | 24 | # Pickle collected data for later comparisons 25 | persistent=yes 26 | 27 | # List of plugins 28 | load-plugins= 29 | 30 | # Allow loading of arbitrary C extensions (risky) 31 | unsafe-load-any-extension=no 32 | 33 | # Limits of the number of parents for a class 34 | max-parents=15 35 | 36 | # Controls whether unused imports should be considered errors 37 | analyse-fallback-blocks=no 38 | 39 | # Tells whether we should check for unused import in __init__ files 40 | init-import=no 41 | 42 | # Disable the "locally-disabled" informational message. 43 | disable=locally-disabled 44 | 45 | # Regexp for lines that are allowed to be longer than the line-length limit (e.g. URLs) 46 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 47 | 48 | [MESSAGES CONTROL] 49 | # Only show warnings with the listed confidence levels. Leave empty to show all. 50 | confidence= 51 | 52 | # Disable the following messages or rules: 53 | disable= 54 | C0103, # invalid-name (variable/function naming) 55 | C0111, # missing-docstring (classes, functions, etc.) 56 | C0303, # trailing-whitespace 57 | C0304, # missing-final-newline 58 | C0209, # consider-using-f-string 59 | C0302, # too-many-lines (module too long) 60 | W0311, # bad-indentation 61 | R0801, # duplicate-code 62 | R0902, # too-many-instance-attributes 63 | R0903, # too-few-public-methods 64 | R0904, # too-many-public-methods 65 | R0913, # too-many-arguments 66 | R0914, # too-many-locals 67 | R0915, # too-many-statements 68 | W0212, # protected-access (we use _params in some places intentionally) 69 | W0511, # fixme (allow TODOs in the code) 70 | W0613, # unused-argument 71 | C0412, # ungrouped-imports 72 | C0413, # wrong-import-position (needed for sys.path manipulation) 73 | C0415, # import-outside-toplevel (for imports in exception handlers) 74 | E0401, # import-error (false positives from custom imports) 75 | W0404, # reimported (reimports in error handling) 76 | 77 | [REPORTS] 78 | # Set the output format 79 | output-format=text 80 | 81 | # Tells whether to display a full report or only the warnings 82 | reports=no 83 | 84 | # Python expression which should return a note less than 10 85 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 86 | 87 | # Template used to display messages 88 | msg-template={path}:{line}:{column}: {msg_id}: {msg} ({symbol}) 89 | 90 | [SIMILARITIES] 91 | # Ignore comments when computing similarities 92 | ignore-comments=yes 93 | 94 | # Ignore docstrings when computing similarities 95 | ignore-docstrings=yes 96 | 97 | # Ignore imports when computing similarities 98 | ignore-imports=yes 99 | 100 | # Minimum lines number of a similarity 101 | min-similarity-lines=8 102 | 103 | [BASIC] 104 | # Regular expression which should only match functions or classes names 105 | function-rgx=[a-z_][a-z0-9_]{2,50}$ 106 | class-rgx=[A-Z_][a-zA-Z0-9_]{2,50}$ 107 | method-rgx=[a-z_][a-z0-9_]{2,50}$ 108 | attr-rgx=[a-z_][a-z0-9_]{1,50}$ 109 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 110 | argument-rgx=[a-z_][a-z0-9_]{1,50}$ 111 | 
variable-rgx=[a-z_][a-z0-9_]{1,50}$ 112 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 113 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{1,50}|(__.*__))$ 114 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 115 | 116 | [IMPORTS] 117 | # Force import order to recognize a module as part of a third party library. 118 | known-third-party=enchant 119 | 120 | [DESIGN] 121 | # Maximum number of arguments for function / method 122 | max-args=10 123 | 124 | # Maximum number of locals for function / method body 125 | max-locals=25 126 | 127 | # Maximum number of return / yield for function / method body 128 | max-returns=11 129 | 130 | # Maximum number of branch for function / method body 131 | max-branches=26 132 | 133 | # Maximum number of statements in function / method body 134 | max-statements=100 135 | 136 | # Maximum number of parents for a class 137 | max-parents=7 138 | 139 | # Maximum number of attributes for a class 140 | max-attributes=20 141 | 142 | # Minimum number of public methods for a class 143 | min-public-methods=0 144 | 145 | # Maximum number of public methods for a class 146 | max-public-methods=25 147 | 148 | [TYPECHECK] 149 | # List of members which are set dynamically and missed by pylint 150 | # inference system, and so shouldn't trigger E1101 when accessed. 151 | generated-members= 152 | numpy.*, 153 | torch.* 154 | 155 | # List of classes names for which member attributes should not be checked 156 | ignored-classes= 157 | optparse.Values, 158 | thread._local, 159 | _thread._local, 160 | numpy, 161 | torch 162 | 163 | [FORMAT] 164 | # Maximum number of characters on a single line 165 | max-line-length=100 166 | 167 | # Maximum number of lines in a module 168 | max-module-lines=2000 169 | 170 | [VARIABLES] 171 | # A regular expression matching names of dummy variables 172 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 173 | 174 | # List of strings which can identify a callback function by name 175 | callbacks=cb_,_cb 176 | 177 | # List of qualified module names which can have objects that can redefine builtins 178 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.5.1 (2025-04-05) 4 | 5 | ### Minor Updates 6 | 7 | - Bump version to 0.5.1 8 | - Documentation improvements 9 | - Internal code quality enhancements 10 | 11 | ## 0.5.0 (2025-04-05) 12 | 13 | ### New Features 14 | 15 | - **Added paragraph detection functionality:** 16 | - New `PunktParagraphTokenizer` for paragraph boundary detection 17 | - Paragraph breaks identified at sentence boundaries with multiple newlines 18 | - API for paragraph tokenization with span information 19 | - **Added sentence and paragraph span extraction:** 20 | - Contiguous spans that preserve all whitespace 21 | - Spans guaranteed to cover entire text without gaps 22 | - API for getting spans with text content 23 | - **Extended public API with new functions:** 24 | - `sent_spans()` and `sent_spans_with_text()` for sentence spans 25 | - `para_tokenize()`, `para_spans()`, and `para_spans_with_text()` for paragraphs 26 | - Implemented singleton pattern for efficient model loading 27 | 28 | 
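The new span and paragraph helpers listed above can be used as follows (a minimal sketch; the sample text is illustrative only):

```python
from nupunkt import para_tokenize, sent_spans_with_text

text = "First sentence. Second sentence.\n\nA new paragraph."

# Sentence spans are contiguous (start, end) character offsets over the text.
for sentence, (start, end) in sent_spans_with_text(text):
    print(f"[{start}:{end}] {sentence!r}")

# Paragraph breaks fall at sentence boundaries followed by two or more newlines.
print(para_tokenize(text))
```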
### Performance Improvements 29 | 30 | - Optimized model loading with caching mechanisms 31 | - Single model instance shared across multiple operations 32 | - Efficient memory usage for repeated sentence/paragraph tokenization 33 | 34 | ### Memory Optimizations 35 | 36 | - **Added memory-efficient training for large text corpora:** 37 | - Early frequency pruning to discard rare items during training 38 | - Streaming processing mode to avoid storing complete token lists 39 | - Batch training for processing very large text collections 40 | - Configurable memory usage parameters 41 | - Added memory benchmarking tools in `.benchmark` directory 42 | - Added documentation for memory-efficient training 43 | - Updated default training script with memory optimization options 44 | 45 | ### Performance Improvements 46 | 47 | - Improved memory usage during training (up to 60% reduction) 48 | - Support for training on very large text collections 49 | - Pruning of low-frequency tokens, collocations, and sentence starters 50 | - Configurable frequency thresholds and pruning intervals 51 | 52 | ## 0.4.0 (2025-03-31) 53 | 54 | ### Performance Improvements 55 | 56 | - **Major tokenization performance optimization:** 57 | - Normal text processing: 31M chars/sec (9% faster) 58 | - Text without sentence endings: 1.4B chars/sec (383% faster) 59 | - Overall tokenization time reduced by 11% 60 | - Function call count reduced by 22% 61 | - PunktToken initialization optimized with token caching and pre-computed properties 62 | - Added fast path optimizations for texts without sentence boundaries 63 | - Improved string handling and regex operations in hot spots 64 | - Added profiling tools for performance analysis and optimization 65 | 66 | ## 0.3.0 (2025-03-31) 67 | 68 | ### New Features 69 | 70 | - Implemented optimized binary model storage format with multiple compression options 71 | - Added utility scripts for working with models (model_info.py, convert_model.py, optimize_model.py) 72 | - Added check_abbreviation.py tool to check if a token is in the model's abbreviation list 73 | - Added general_abbreviations.json file with common English abbreviations 74 | - Updated training process to use both legal and general abbreviation lists 75 | - Improved testing tools with test_tokenizer.py 76 | - Added benchmarking utilities to compare model loading and tokenization performance 77 | - Added profiling tools for performance analysis and optimization 78 | 79 | ### Performance Improvements 80 | 81 | - Reduced default model size by 32% using binary LZMA format (1.5MB vs 2.2MB) 82 | - Better memory usage during model loading 83 | - Automatic format selection prioritizing the most efficient format 84 | - **Major tokenization performance optimization:** 85 | - Normal text processing: 31M chars/sec (9% faster) 86 | - Text without sentence endings: 1.4B chars/sec (383% faster) 87 | - Overall tokenization time reduced by 11% 88 | - Function call count reduced by 22% 89 | - PunktToken initialization optimized with token caching and pre-computed properties 90 | - Added fast path optimizations for texts without sentence boundaries 91 | - Improved string handling and regex operations in hot spots 92 | 93 | ## 0.2.0 (2025-03-30) 94 | 95 | ### New Features 96 | 97 | - Initial release of nupunkt (renamed from punkt2) 98 | - Added compression support for model files using LZMA 99 | - Improved documentation -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ALEA Institute 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include CHANGELOG.md 4 | include nupunkt/py.typed 5 | include nupunkt/models/default_model.bin 6 | 7 | recursive-include docs *.md 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[cod] 10 | recursive-exclude * *.so 11 | recursive-exclude * .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nupunkt 2 | 3 | **nupunkt** is a next-generation implementation of the Punkt algorithm for sentence boundary detection with zero runtime dependencies. 4 | 5 | [![PyPI version](https://badge.fury.io/py/nupunkt.svg)](https://badge.fury.io/py/nupunkt) 6 | [![Python Version](https://img.shields.io/pypi/pyversions/nupunkt.svg)](https://pypi.org/project/nupunkt/) 7 | [![License](https://img.shields.io/github/license/alea-institute/nupunkt.svg)](https://github.com/alea-institute/nupunkt/blob/main/LICENSE) 8 | 9 | ## Overview 10 | 11 | nupunkt accurately detects sentence boundaries in text, even in challenging cases where periods are used for abbreviations, ellipses, and other non-sentence-ending contexts. It's built on the statistical principles of the Punkt algorithm, with modern enhancements for improved handling of edge cases. 12 | 13 | Key features: 14 | - **Minimal dependencies**: Only requires Python 3.11+ and tqdm for progress bars 15 | - **Pre-trained model**: Ready to use out of the box 16 | - **Fast and accurate**: Optimized implementation of the Punkt algorithm 17 | - **Trainable**: Can be trained on domain-specific text 18 | - **Full support for ellipsis**: Handles various ellipsis patterns 19 | - **Type annotations**: Complete type hints for better IDE integration 20 | 21 | ## Installation 22 | 23 | ```bash 24 | pip install nupunkt 25 | ``` 26 | 27 | ## Quick Start 28 | 29 | ```python 30 | from nupunkt import sent_tokenize 31 | 32 | text = """ 33 | Employee also specifically and forever releases the Acme Inc. 
(Company) and the Company Parties (except where and 34 | to the extent that such a release is expressly prohibited or made void by law) from any claims based on unlawful 35 | employment discrimination or harassment, including, but not limited to, the Federal Age Discrimination in 36 | Employment Act (29 U.S.C. § 621 et. seq.). This release does not include Employee’s right to indemnification, 37 | and related insurance coverage, under Sec. 7.1.4 or Ex. 1-1 of the Employment Agreement, his right to equity awards, 38 | or continued exercise, pursuant to the terms of any specific equity award (or similar) agreement between 39 | Employee and the Company nor to Employee’s right to benefits under any Company plan or program in which 40 | Employee participated and is due a benefit in accordance with the terms of the plan or program as of the Effective 41 | Date and ending at 11:59 p.m. Eastern Time on Sep. 15, 2013. 42 | """ 43 | 44 | # Tokenize into sentences 45 | sentences = sent_tokenize(text) 46 | 47 | # Print the results 48 | for i, sentence in enumerate(sentences, 1): 49 | print(f"Sentence {i}: {sentence}\n") 50 | ``` 51 | 52 | Output: 53 | ``` 54 | Sentence 1: 55 | Employee also specifically and forever releases the Acme Inc. (Company) and the Company Parties (except where and 56 | to the extent that such a release is expressly prohibited or made void by law) from any claims based on unlawful 57 | employment discrimination or harassment, including, but not limited to, the Federal Age Discrimination in 58 | Employment Act (29 U.S.C. § 621 et. seq.). 59 | 60 | Sentence 2: This release does not include Employee’s right to indemnification, 61 | and related insurance coverage, under Sec. 7.1.4 or Ex. 1-1 of the Employment Agreement, his right to equity awards, 62 | or continued exercise, pursuant to the terms of any specific equity award (or similar) agreement between 63 | Employee and the Company nor to Employee’s right to benefits under any Company plan or program in which 64 | Employee participated and is due a benefit in accordance with the terms of the plan or program as of the Effective 65 | Date and ending at 11:59 p.m. Eastern Time on Sep. 15, 2013. 66 | ``` 67 | 68 | ## Documentation 69 | 70 | For more detailed documentation, see the [docs](./docs) directory: 71 | 72 | - [Overview](./docs/overview.md) 73 | - [Getting Started](./docs/getting_started.md) 74 | - [API Reference](./docs/api_reference.md) 75 | - [Architecture](./docs/architecture.md) 76 | - [Training Models](./docs/training_models.md) 77 | - [Advanced Usage](./docs/advanced_usage.md) 78 | 79 | ## Command-line Tools 80 | 81 | nupunkt comes with several utility scripts for working with models: 82 | 83 | - **check_abbreviation.py**: Check if a token is in the model's abbreviation list 84 | ```bash 85 | python -m scripts.utils.check_abbreviation "U.S." 86 | python -m scripts.utils.check_abbreviation --list # List all abbreviations 87 | python -m scripts.utils.check_abbreviation --count # Count abbreviations 88 | ``` 89 | 90 | - **test_tokenizer.py**: Test the tokenizer on sample text 91 | - **model_info.py**: Display information about a model file 92 | 93 | See the [scripts/utils/README.md](./scripts/utils/README.md) for more details on available tools. 
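For example (hypothetical invocations; the exact arguments each script accepts are documented in the utils README and may differ):

```bash
# Hypothetical invocations -- see scripts/utils/README.md for the actual arguments.
python -m scripts.utils.test_tokenizer "Dr. Smith went to Washington D.C. yesterday."
python -m scripts.utils.model_info nupunkt/models/default_model.bin
```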
94 | 95 | ## Advanced Example 96 | 97 | ```python 98 | from nupunkt import PunktTrainer, PunktSentenceTokenizer 99 | 100 | # Train a new model on domain-specific text 101 | with open("legal_corpus.txt", "r", encoding="utf-8") as f: 102 | legal_text = f.read() 103 | 104 | trainer = PunktTrainer(legal_text, verbose=True) 105 | params = trainer.get_params() 106 | 107 | # Save the trained model 108 | trainer.save("legal_model.json") 109 | 110 | # Create a tokenizer with the trained parameters 111 | tokenizer = PunktSentenceTokenizer(params) 112 | 113 | # Tokenize legal text 114 | legal_sample = "The court ruled in favor of the plaintiff. 28 U.S.C. § 1332 provides jurisdiction." 115 | sentences = tokenizer.tokenize(legal_sample) 116 | 117 | for s in sentences: 118 | print(s) 119 | ``` 120 | 121 | ## Performance 122 | 123 | nupunkt is designed to be both accurate and efficient. It can process large volumes of text quickly, making it suitable for production NLP pipelines. 124 | 125 | ### Highly Optimized 126 | 127 | The tokenizer has been extensively optimized for performance: 128 | - **Token caching** for common tokens 129 | - **Fast path processing** for texts without sentence boundaries (up to 1.4B chars/sec) 130 | - **Pre-computed properties** to avoid repeated calculations 131 | - **Efficient character processing** and string handling in hot spots 132 | 133 | ### Example Legal Domain Benchmark 134 | ``` 135 | Performance Results: 136 | Documents processed: 1 137 | Total characters: 16,567,769 138 | Total sentences found: 16,095 139 | Processing time: 0.49 seconds 140 | Processing speed: 33,927,693 characters/second 141 | Average sentence length: 1029.4 characters 142 | ``` 143 | 144 | ### Specialized Use Cases 145 | - Normal text processing: ~31M characters/second 146 | - Text without sentence boundaries: ~1.4B characters/second 147 | - Short text fragments: Extremely fast with early exit paths 148 | 149 | ## Contributing 150 | 151 | Contributions are welcome! Please feel free to submit a Pull Request. 152 | 153 | ## License 154 | 155 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 156 | 157 | ## Acknowledgments 158 | 159 | nupunkt is based on the Punkt algorithm originally developed by Tibor Kiss and Jan Strunk. 
-------------------------------------------------------------------------------- /data/general_abbreviations.json: -------------------------------------------------------------------------------- 1 | [ 2 | "Jan.", 3 | "Feb.", 4 | "Mar.", 5 | "Apr.", 6 | "Jun.", 7 | "Jul.", 8 | "Aug.", 9 | "Sep.", 10 | "Sept.", 11 | "Oct.", 12 | "Nov.", 13 | "Dec.", 14 | "Mon.", 15 | "Tue.", 16 | "Tues.", 17 | "Wed.", 18 | "Thu.", 19 | "Thur.", 20 | "Thurs.", 21 | "Fri.", 22 | "Sat.", 23 | "Sun.", 24 | "Mr.", 25 | "Mrs.", 26 | "Ms.", 27 | "Dr.", 28 | "Prof.", 29 | "Rev.", 30 | "Hon.", 31 | "St.", 32 | "Jr.", 33 | "Sr.", 34 | "Ph.D.", 35 | "M.D.", 36 | "B.A.", 37 | "M.A.", 38 | "B.Sc.", 39 | "M.Sc.", 40 | "B.S.", 41 | "M.S.", 42 | "MBA.", 43 | "CEO.", 44 | "CFO.", 45 | "CTO.", 46 | "COO.", 47 | "VP.", 48 | "Dept.", 49 | "Govt.", 50 | "Corp.", 51 | "Inc.", 52 | "Ltd.", 53 | "PLC.", 54 | "Co.", 55 | "assoc.", 56 | "Ave.", 57 | "Blvd.", 58 | "Rd.", 59 | "Dr.", 60 | "Ln.", 61 | "Pl.", 62 | "Ct.", 63 | "Pkwy.", 64 | "Hwy.", 65 | "Rt.", 66 | "Apt.", 67 | "Ste.", 68 | "Bldg.", 69 | "No.", 70 | "vs.", 71 | "etc.", 72 | "i.e.", 73 | "e.g.", 74 | "a.m.", 75 | "p.m.", 76 | "ca.", 77 | "approx.", 78 | "est.", 79 | "min.", 80 | "sec.", 81 | "hr.", 82 | "ft.", 83 | "in.", 84 | "lb.", 85 | "oz.", 86 | "kg.", 87 | "gal.", 88 | "pt.", 89 | "qt.", 90 | "tbsp.", 91 | "tsp.", 92 | "adj.", 93 | "adv.", 94 | "conj.", 95 | "prep.", 96 | "n.", 97 | "v.", 98 | "fig." 99 | ] -------------------------------------------------------------------------------- /data/test.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alea-institute/nupunkt/29d056aba0f6c9e0f43ee1e36d2638260027af0c/data/test.jsonl.gz -------------------------------------------------------------------------------- /data/train2.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alea-institute/nupunkt/29d056aba0f6c9e0f43ee1e36d2638260027af0c/data/train2.jsonl.gz -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # nupunkt Documentation 2 | 3 | This folder contains documentation for the nupunkt library, a next-generation implementation of the Punkt algorithm for sentence boundary detection. 4 | 5 | ## Documentation Contents 6 | 7 | - [Overview](overview.md): Introduction to nupunkt and its features 8 | - [Getting Started](getting_started.md): Installation and basic usage 9 | - [API Reference](api_reference.md): Detailed API documentation 10 | - [Architecture](architecture.md): Package architecture and module descriptions 11 | - [Training Models](training_models.md): How to train custom models 12 | - [Advanced Usage](advanced_usage.md): Advanced usage patterns and customization -------------------------------------------------------------------------------- /docs/advanced_usage.md: -------------------------------------------------------------------------------- 1 | # Advanced Usage 2 | 3 | This guide covers advanced usage patterns and customization options for nupunkt. 4 | 5 | ## Memory-Efficient Training 6 | 7 | NUpunkt now includes several optimizations for training on very large text collections with manageable memory usage. These optimizations allow you to train on much larger corpora than was previously possible. 8 | 9 | ### Memory Optimization Techniques 10 | 11 | 1. 
**Early Pruning**: Discard low-frequency items during training 12 | 2. **Streaming Processing**: Process text without storing complete token lists 13 | 3. **Batch Training**: Process text in manageable chunks 14 | 4. **Memory Configuration**: Fine-tune memory usage via parameters 15 | 16 | ### Basic Memory-Efficient Training 17 | 18 | ```python 19 | from nupunkt.trainers.base_trainer import PunktTrainer 20 | 21 | # Create a memory-efficient trainer 22 | trainer = PunktTrainer(memory_efficient=True, verbose=True) 23 | 24 | # Train with streaming mode (avoids storing all tokens at once) 25 | trainer.train(text) 26 | 27 | # Get the trained parameters 28 | params = trainer.get_params() 29 | ``` 30 | 31 | ### Batch Training for Very Large Corpora 32 | 33 | For extremely large text collections, you can use batch training: 34 | 35 | ```python 36 | from nupunkt.trainers.base_trainer import PunktTrainer 37 | 38 | # Create a trainer 39 | trainer = PunktTrainer(verbose=True) 40 | 41 | # Split text into batches 42 | batches = PunktTrainer.text_to_batches(huge_text, batch_size=1000000) 43 | 44 | # Train in batches 45 | trainer.train_batches(batches, verbose=True) 46 | ``` 47 | 48 | ### Memory Configuration Parameters 49 | 50 | You can fine-tune the memory usage with these parameters: 51 | 52 | ```python 53 | trainer = PunktTrainer(memory_efficient=True) 54 | 55 | # Configure memory usage 56 | trainer.TYPE_FDIST_MIN_FREQ = 2 # Minimum frequency to keep a type 57 | trainer.COLLOC_FDIST_MIN_FREQ = 3 # Minimum frequency for collocations 58 | trainer.SENT_STARTER_MIN_FREQ = 2 # Minimum frequency for sentence starters 59 | trainer.PRUNE_INTERVAL = 10000 # How often to prune (token count) 60 | trainer.CHUNK_SIZE = 10000 # Size of token chunks for processing 61 | ``` 62 | 63 | ### Command-Line Usage 64 | 65 | When using the default training script, you can enable memory optimizations: 66 | 67 | ```bash 68 | python -m scripts.train_default_model \ 69 | --memory-efficient \ 70 | --min-type-freq 2 \ 71 | --prune-freq 10000 \ 72 | --use-batches \ 73 | --batch-size 1000000 74 | ``` 75 | 76 | ### Memory Impact 77 | 78 | The memory optimizations can significantly reduce memory usage: 79 | 80 | | Optimization | Memory Reduction | Impact on Model Quality | 81 | |--------------|------------------|-------------------------| 82 | | Early Pruning | ~30-50% | Minimal (removes rare items) | 83 | | Streaming Processing | ~40-60% | None (same algorithm) | 84 | | Batch Training | Scales to any size | None (same algorithm) | 85 | 86 | ## Model Compression 87 | 88 | Models in nupunkt are saved with LZMA compression by default, which significantly reduces file size while maintaining fast loading times. 
You can control compression settings when saving models: 89 | 90 | ```python 91 | from nupunkt import PunktTrainer 92 | 93 | # Train a model 94 | trainer = PunktTrainer(training_text) 95 | 96 | # Save with default compression (level 1 - fast compression) 97 | trainer.save("my_model.json") # Creates my_model.json.xz 98 | 99 | # Save with higher compression (smaller file, slower compression) 100 | trainer.save("my_model_high_compression.json", compression_level=6) 101 | 102 | # Save without compression 103 | trainer.save("my_model_uncompressed.json", compress=False) 104 | ``` 105 | 106 | Loading compressed models is transparent - the library automatically detects and handles compressed files: 107 | 108 | ```python 109 | from nupunkt import PunktSentenceTokenizer 110 | 111 | # Both of these will work regardless of whether the model is compressed 112 | tokenizer1 = PunktSentenceTokenizer.load("my_model.json") 113 | tokenizer2 = PunktSentenceTokenizer.load("my_model.json.xz") 114 | ``` 115 | 116 | Compressing an existing uncompressed model: 117 | 118 | ```python 119 | from nupunkt.models import compress_default_model 120 | 121 | # Compress the default model 122 | compressed_path = compress_default_model() 123 | print(f"Compressed model saved to: {compressed_path}") 124 | 125 | # Compress with custom settings 126 | custom_path = compress_default_model("custom_model.json.xz", compression_level=3) 127 | ``` 128 | 129 | ## Custom Language Variables 130 | 131 | For handling different languages or specific text domains, you can customize the language variables: 132 | 133 | ```python 134 | from nupunkt import PunktLanguageVars, PunktSentenceTokenizer 135 | 136 | class FrenchLanguageVars(PunktLanguageVars): 137 | # French uses colons as sentence endings 138 | sent_end_chars = (".", "?", "!", ":") 139 | 140 | # Customize internal punctuation 141 | internal_punctuation = ",:;«»" 142 | 143 | # Customize word tokenization pattern if needed 144 | _re_word_start = r"[^\(\"\`{\[:;&\#\*@\)}\]\-,«»]" 145 | 146 | # Create a tokenizer with the custom language variables 147 | french_vars = FrenchLanguageVars() 148 | tokenizer = PunktSentenceTokenizer(lang_vars=french_vars) 149 | 150 | # Tokenize French text 151 | french_text = "Bonjour! Comment allez-vous? Très bien, merci." 152 | sentences = tokenizer.tokenize(french_text) 153 | ``` 154 | 155 | ## Custom Token Class 156 | 157 | You can extend the `PunktToken` class to add additional functionality: 158 | 159 | ```python 160 | from nupunkt import PunktToken, PunktSentenceTokenizer 161 | from typing import Optional 162 | 163 | class EnhancedToken(PunktToken): 164 | def __init__(self, tok, parastart=False, linestart=False, 165 | sentbreak=False, abbr=False, ellipsis=False, 166 | pos_tag: Optional[str] = None): 167 | super().__init__(tok, parastart, linestart, sentbreak, abbr, ellipsis) 168 | self.pos_tag = pos_tag 169 | 170 | @property 171 | def is_noun(self) -> bool: 172 | return self.pos_tag == "NOUN" if self.pos_tag else False 173 | 174 | # Create a tokenizer with the custom token class 175 | tokenizer = PunktSentenceTokenizer(token_cls=EnhancedToken) 176 | ``` 177 | 178 | ## Working with Spans 179 | 180 | For applications that need character-level positions: 181 | 182 | ```python 183 | from nupunkt import load_default_model 184 | 185 | tokenizer = load_default_model() 186 | text = "Hello world. This is a test." 
187 | 188 | # Get character-level spans for each sentence 189 | spans = list(tokenizer.span_tokenize(text)) 190 | print(spans) # [(0, 12), (13, 28)] 191 | 192 | # Useful for highlighting or extracting sentences 193 | for start, end in spans: 194 | print(f"Sentence: {text[start:end]}") 195 | ``` 196 | 197 | ## Customizing Boundary Realignment 198 | 199 | By default, nupunkt realigns sentence boundaries to handle trailing punctuation like quotes. You can disable this: 200 | 201 | ```python 202 | sentences = tokenizer.tokenize(text, realign_boundaries=False) 203 | ``` 204 | 205 | ## Reconfiguring Tokenizers 206 | 207 | You can update tokenizer settings without retraining: 208 | 209 | ```python 210 | from nupunkt import PunktSentenceTokenizer 211 | 212 | tokenizer = PunktSentenceTokenizer.load("my_model.json") 213 | 214 | # Update configuration 215 | config = { 216 | "parameters": { 217 | "abbrev_types": ["Dr", "Mr", "Mrs", "Ms", "Prof", "Inc", "Co"], 218 | # Add custom abbreviations 219 | } 220 | } 221 | 222 | tokenizer.reconfigure(config) 223 | ``` 224 | 225 | ## Sentence Internal Spans 226 | 227 | To get spans of sentences excluding certain punctuation: 228 | 229 | ```python 230 | import re 231 | from nupunkt import load_default_model 232 | 233 | tokenizer = load_default_model() 234 | text = "\"Hello,\" he said. \"How are you?\"" 235 | 236 | # Get spans for each sentence 237 | spans = list(tokenizer.span_tokenize(text)) 238 | 239 | # Clean up internal spans by removing quotation marks 240 | for start, end in spans: 241 | sentence = text[start:end] 242 | # Remove leading/trailing quotes and whitespace 243 | cleaned = re.sub(r'^[\s"\']+|[\s"\']+$', '', sentence) 244 | print(f"Original: {sentence}") 245 | print(f"Cleaned: {cleaned}") 246 | ``` 247 | 248 | ## Parallelizing Tokenization 249 | 250 | For processing large volumes of text, you can parallelize the work: 251 | 252 | ```python 253 | from nupunkt import load_default_model 254 | from concurrent.futures import ProcessPoolExecutor 255 | import multiprocessing 256 | 257 | def tokenize_text(text): 258 | tokenizer = load_default_model() 259 | return tokenizer.tokenize(text) 260 | 261 | # Break large text into chunks 262 | def chunk_text(text, chunk_size=100000): 263 | # Simple chunking by character count 264 | # More sophisticated chunking could preserve paragraph/document boundaries 265 | return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] 266 | 267 | # Process in parallel 268 | def parallel_tokenize(text, num_workers=None): 269 | if num_workers is None: 270 | num_workers = multiprocessing.cpu_count() 271 | 272 | chunks = chunk_text(text) 273 | all_sentences = [] 274 | 275 | with ProcessPoolExecutor(max_workers=num_workers) as executor: 276 | chunk_results = list(executor.map(tokenize_text, chunks)) 277 | 278 | # Flatten results 279 | for sentences in chunk_results: 280 | all_sentences.extend(sentences) 281 | 282 | return all_sentences 283 | ``` 284 | 285 | ## Customizing JSON Serialization 286 | 287 | If you need to customize how models are saved and loaded: 288 | 289 | ```python 290 | from nupunkt import PunktParameters, PunktTrainer 291 | import json 292 | 293 | # Extended parameters with custom metadata 294 | class ExtendedParameters(PunktParameters): 295 | def __init__(self, metadata=None, **kwargs): 296 | super().__init__(**kwargs) 297 | self.metadata = metadata or {} 298 | 299 | def to_json(self): 300 | data = super().to_json() 301 | data["metadata"] = self.metadata 302 | return data 303 | 304 | @classmethod 305 | 
def from_json(cls, data): 306 | params = super().from_json(data) 307 | params.metadata = data.get("metadata", {}) 308 | return params 309 | 310 | # Create parameters with metadata 311 | params = ExtendedParameters(metadata={"domain": "legal", "source": "case_law"}) 312 | 313 | # Add some data 314 | params.abbrev_types.add("etc") 315 | 316 | # Save with pretty formatting 317 | with open("custom_params.json", "w") as f: 318 | json.dump(params.to_json(), f, indent=2) 319 | ``` 320 | 321 | ## Debug and Visualization Tools 322 | 323 | To understand how nupunkt is making decisions: 324 | 325 | ```python 326 | from nupunkt import PunktSentenceTokenizer, PunktToken 327 | 328 | # Create a debugging tokenizer 329 | class DebugTokenizer(PunktSentenceTokenizer): 330 | def tokenize_with_debug(self, text): 331 | # Tokenize and collect debugging info 332 | tokens = list(self._tokenize_words(text)) 333 | first_pass = list(self._annotate_first_pass(tokens)) 334 | 335 | # Second pass with decision tracking 336 | decisions = [] 337 | for token1, token2 in self._pair_iter(first_pass): 338 | if token1.period_final or token1.tok in self._lang_vars.sent_end_chars: 339 | decision = self._second_pass_annotation(token1, token2) 340 | decisions.append((token1, decision)) 341 | 342 | # Return the debug info along with sentences 343 | return { 344 | "sentences": self.tokenize(text), 345 | "tokens": tokens, 346 | "first_pass": first_pass, 347 | "decisions": decisions 348 | } 349 | 350 | # Use the debug tokenizer 351 | debug_tokenizer = DebugTokenizer() 352 | results = debug_tokenizer.tokenize_with_debug("Dr. Smith went to Washington D.C. yesterday.") 353 | 354 | # Print debug info 355 | for token, decision in results["decisions"]: 356 | print(f"Token: {token.tok}, Abbr: {token.abbr}, SentBreak: {token.sentbreak}, Decision: {decision}") 357 | ``` 358 | 359 | These advanced usage patterns should help you customize nupunkt for specific needs and troubleshoot any issues that arise. -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # nupunkt Architecture 2 | 3 | This document describes the architecture and modules of the nupunkt package. 
4 | 5 | ## Package Structure 6 | 7 | The nupunkt package is organized into the following modules: 8 | 9 | ``` 10 | nupunkt/ 11 | ├── __init__.py # Package initialization and public API 12 | ├── nupunkt.py # Main implementation file 13 | ├── py.typed # Type checking marker 14 | ├── core/ # Core components 15 | │ ├── __init__.py 16 | │ ├── base.py # Base classes 17 | │ ├── constants.py # Constant definitions 18 | │ ├── language_vars.py # Language variables 19 | │ ├── parameters.py # Algorithm parameters 20 | │ └── tokens.py # Token representation 21 | ├── models/ # Model handling 22 | │ ├── __init__.py 23 | │ └── default_model.json # Pre-trained model 24 | ├── tokenizers/ # Tokenization components 25 | │ ├── __init__.py 26 | │ └── sentence_tokenizer.py # Sentence tokenizer 27 | ├── trainers/ # Training components 28 | │ ├── __init__.py 29 | │ └── base_trainer.py # Trainer implementation 30 | └── utils/ # Utility functions 31 | ├── __init__.py 32 | ├── iteration.py # Iteration helpers 33 | └── statistics.py # Statistical functions 34 | ``` 35 | 36 | ## Module Descriptions 37 | 38 | ### Core Module 39 | 40 | The `core` module provides the fundamental building blocks for the sentence tokenization process: 41 | 42 | - **base.py**: Contains the `PunktBase` class, which provides common functionality used by both trainers and tokenizers. 43 | - **constants.py**: Defines orthographic context constants used to track capitalization patterns. 44 | - **language_vars.py**: Provides the `PunktLanguageVars` class, which encapsulates language-specific behaviors. 45 | - **parameters.py**: Contains the `PunktParameters` class that stores the learned parameters (abbreviations, collocations, etc.). 46 | - **tokens.py**: Defines the `PunktToken` class, which represents tokens with various attributes. 47 | 48 | ### Models Module 49 | 50 | The `models` module handles model loading and provides a pre-trained default model: 51 | 52 | - **__init__.py**: Contains functions to load the default model. 53 | - **default_model.json**: A pre-trained model ready for general use. 54 | 55 | ### Tokenizers Module 56 | 57 | The `tokenizers` module contains the sentence tokenizer implementation: 58 | 59 | - **sentence_tokenizer.py**: Implements the `PunktSentenceTokenizer` class, which performs the actual sentence boundary detection using trained parameters. 60 | 61 | ### Trainers Module 62 | 63 | The `trainers` module handles training new models from text: 64 | 65 | - **base_trainer.py**: Contains the `PunktTrainer` class, which learns parameters from training text. 66 | 67 | ### Utils Module 68 | 69 | The `utils` module provides utility functions used throughout the package: 70 | 71 | - **iteration.py**: Contains utilities for iteration, like `pair_iter` for iterating through pairs of items. 72 | - **statistics.py**: Provides statistical functions for calculating log-likelihood and other measures. 73 | 74 | ## Main Classes 75 | 76 | ### PunktSentenceTokenizer 77 | 78 | The primary class for tokenizing text into sentences. It uses trained parameters to identify sentence boundaries. 
79 | 80 | ```python 81 | tokenizer = PunktSentenceTokenizer() 82 | sentences = tokenizer.tokenize(text) 83 | ``` 84 | 85 | ### PunktTrainer 86 | 87 | Used to train new models on domain-specific text: 88 | 89 | ```python 90 | trainer = PunktTrainer(train_text, verbose=True) 91 | params = trainer.get_params() 92 | tokenizer = PunktSentenceTokenizer(params) 93 | ``` 94 | 95 | ### PunktParameters 96 | 97 | Stores the learned parameters that guide the tokenization process: 98 | 99 | - **abbrev_types**: Set of known abbreviations 100 | - **collocations**: Set of word pairs that often occur across sentence boundaries 101 | - **sent_starters**: Set of words that often start sentences 102 | - **ortho_context**: Dictionary tracking capitalization patterns 103 | 104 | ### PunktLanguageVars 105 | 106 | Contains language-specific variables and settings that can be customized: 107 | 108 | - **sent_end_chars**: Characters that can end sentences 109 | - **internal_punctuation**: Characters considered internal punctuation 110 | - **word_tokenize_pattern**: Pattern for tokenizing words 111 | 112 | ## Data Flow 113 | 114 | 1. Text is tokenized into words using `PunktLanguageVars.word_tokenize()`. 115 | 2. Tokens are annotated in the first pass to identify sentence breaks, abbreviations, and ellipses. 116 | 3. Tokens are annotated in the second pass using collocational and orthographic heuristics. 117 | 4. Sentence boundaries are determined based on the annotated tokens. 118 | 5. Boundaries are optionally realigned to handle trailing punctuation. 119 | 6. The resulting sentences or spans are returned. 120 | 121 | ## Algorithm Workflow 122 | 123 | ### Training 124 | 125 | 1. Count token frequencies in the training text. 126 | 2. Identify potential abbreviations based on statistical measures. 127 | 3. Annotate tokens with sentence breaks, abbreviations, and ellipses. 128 | 4. Gather orthographic context data (capitalization patterns). 129 | 5. Identify collocations and sentence starters. 130 | 6. Finalize the parameters for use in tokenization. 131 | 132 | ### Tokenization 133 | 134 | 1. Break the text into words. 135 | 2. Annotate tokens with sentence breaks, abbreviations, and ellipses. 136 | 3. Apply collocational and orthographic heuristics to refine annotations. 137 | 4. Use sentence breaks to slice the text into sentences. 138 | 5. Realign boundaries if requested. 139 | 6. Return sentences or character spans. -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting Started with nupunkt 2 | 3 | This guide will help you install nupunkt and get started with basic sentence tokenization. 4 | 5 | ## Installation 6 | 7 | ### Using pip 8 | 9 | ```bash 10 | pip install nupunkt 11 | ``` 12 | 13 | ### From source 14 | 15 | ```bash 16 | git clone https://github.com/alea-institute/nupunkt.git 17 | cd nupunkt 18 | pip install -e . 19 | ``` 20 | 21 | For development, install with additional dependencies: 22 | 23 | ```bash 24 | pip install -e ".[dev]" 25 | ``` 26 | 27 | ## Basic Usage 28 | 29 | ### Quick Sentence Tokenization 30 | 31 | For quick tokenization using the default pre-trained model: 32 | 33 | ```python 34 | from nupunkt import sent_tokenize 35 | 36 | text = "Hello world. This is a test. Mr. Smith went to Washington D.C. yesterday." 
37 | sentences = sent_tokenize(text) 38 | 39 | for i, sentence in enumerate(sentences, 1): 40 | print(f"Sentence {i}: {sentence}") 41 | ``` 42 | 43 | Output: 44 | ``` 45 | Sentence 1: Hello world. 46 | Sentence 2: This is a test. 47 | Sentence 3: Mr. Smith went to Washington D.C. yesterday. 48 | ``` 49 | 50 | ### Creating a Tokenizer Instance 51 | 52 | If you need more control or plan to tokenize multiple texts: 53 | 54 | ```python 55 | from nupunkt import PunktSentenceTokenizer 56 | from nupunkt.models import load_default_model 57 | 58 | # Load the default pre-trained model 59 | tokenizer = load_default_model() 60 | 61 | # Or create a new tokenizer instance 62 | # tokenizer = PunktSentenceTokenizer() 63 | 64 | text = "Hello world! Is this a sentence? Yes, it is. Dr. Smith teaches at the U.S.C." 65 | sentences = tokenizer.tokenize(text) 66 | 67 | for i, sentence in enumerate(sentences, 1): 68 | print(f"Sentence {i}: {sentence}") 69 | ``` 70 | 71 | Output: 72 | ``` 73 | Sentence 1: Hello world! 74 | Sentence 2: Is this a sentence? 75 | Sentence 3: Yes, it is. 76 | Sentence 4: Dr. Smith teaches at the U.S.C. 77 | ``` 78 | 79 | ### Getting Sentence Spans 80 | 81 | If you need character offsets for the sentences: 82 | 83 | ```python 84 | spans = list(tokenizer.span_tokenize(text)) 85 | for start, end in spans: 86 | print(f"Span ({start}, {end}): {text[start:end]}") 87 | ``` 88 | 89 | ## Common Options 90 | 91 | ### Handling Boundary Realignment 92 | 93 | By default, nupunkt realigns sentence boundaries to handle trailing punctuation. You can disable this behavior: 94 | 95 | ```python 96 | sentences = tokenizer.tokenize(text, realign_boundaries=False) 97 | ``` 98 | 99 | ### Custom Language Variables 100 | 101 | You can customize language-specific behavior: 102 | 103 | ```python 104 | from nupunkt import PunktLanguageVars, PunktSentenceTokenizer 105 | 106 | class GermanLanguageVars(PunktLanguageVars): 107 | # Add German-specific customizations 108 | sent_end_chars = (".", "?", "!", ":") 109 | 110 | german_vars = GermanLanguageVars() 111 | tokenizer = PunktSentenceTokenizer(lang_vars=german_vars) 112 | ``` 113 | 114 | ## Next Steps 115 | 116 | - See [Training Models](training_models.md) for training your own model 117 | - Explore [Advanced Usage](advanced_usage.md) for more customization options 118 | - Check the [API Reference](api_reference.md) for detailed information on all classes and methods -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # nupunkt Overview 2 | 3 | nupunkt is a Python library for sentence boundary detection based on the Punkt algorithm. It's designed to be lightweight, fast, and accurate, with a focus on handling the complexities of real-world text. 4 | 5 | ## What is Punkt? 6 | 7 | The Punkt algorithm, originally developed by Tibor Kiss and Jan Strunk, is an unsupervised approach to sentence boundary detection. It uses statistical methods to learn which periods indicate sentence boundaries versus those that are part of abbreviations, ellipses, or other non-terminal uses. 
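For a concrete sense of the problem, consider the quick-start example from the getting-started guide: the periods after "Mr" and inside "D.C." are not sentence boundaries, and the bundled `sent_tokenize` helper keeps them inside a single sentence.

```python
from nupunkt import sent_tokenize

text = "Hello world. This is a test. Mr. Smith went to Washington D.C. yesterday."
for sentence in sent_tokenize(text):
    print(sentence)
# Three sentences are printed; "Mr." and "D.C." are treated as abbreviations,
# so the final sentence is not split apart.
```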
8 | 9 | ## Key Features 10 | 11 | - **Zero runtime dependencies**: nupunkt is designed to work with minimal dependencies 12 | - **Pre-trained models**: Comes with a default model ready for use 13 | - **Customizable**: Can be trained on domain-specific text 14 | - **Ellipsis handling**: Special handling for various ellipsis patterns 15 | - **Performance optimized**: Efficient implementation for processing large text volumes 16 | - **Fully typed**: Complete type annotations for better IDE integration 17 | 18 | ## How It Works 19 | 20 | nupunkt's approach to sentence boundary detection involves: 21 | 22 | 1. **Tokenization**: Breaking text into tokens 23 | 2. **Abbreviation detection**: Identifying abbreviations that end with periods 24 | 3. **Collocation identification**: Finding word pairs that tend to occur together across sentence boundaries 25 | 4. **Sentence starter recognition**: Learning which words typically start sentences 26 | 5. **Orthographic context analysis**: Using capitalization patterns to identify sentence boundaries 27 | 28 | The algorithm works in multiple passes to annotate tokens with sentence breaks, abbreviations, and other features, ultimately producing accurate sentence boundaries even in challenging text. 29 | 30 | ## When to Use nupunkt 31 | 32 | nupunkt is particularly useful for: 33 | 34 | - Natural language processing pipelines 35 | - Text preprocessing for machine learning 36 | - Extracting sentences from large text corpora 37 | - Legal and scientific text processing where abbreviations are common 38 | - Any application requiring accurate sentence boundary detection 39 | 40 | ## Comparison with Other Tools 41 | 42 | Unlike many other tokenizers, nupunkt: 43 | 44 | - Doesn't rely on hand-crafted rules or large language models 45 | - Can adapt to domain-specific abbreviations and patterns through training 46 | - Handles ellipses and other complex punctuation patterns 47 | - Has minimal dependencies while maintaining high accuracy -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.8 3 | warn_return_any = True 4 | warn_unused_configs = True 5 | disallow_untyped_calls = False 6 | disallow_untyped_defs = False 7 | disallow_incomplete_defs = False 8 | check_untyped_defs = True 9 | disallow_subclassing_any = False 10 | warn_redundant_casts = False 11 | warn_no_return = True 12 | warn_unreachable = True 13 | allow_redefinition = True 14 | strict_equality = True 15 | ignore_missing_imports = True 16 | show_error_codes = True 17 | 18 | # Don't flag "list" vs "List" syntax issues to be compatible with Python 3.9+ 19 | disallow_any_unimported = False 20 | disallow_any_generics = False 21 | 22 | # Ignore missing stubs for external libraries 23 | [mypy.plugins.numpy.*] 24 | ignore_missing_imports = True 25 | 26 | [mypy.plugins.pandas.*] 27 | ignore_missing_imports = True 28 | 29 | [mypy.plugins.tqdm.*] 30 | ignore_missing_imports = True 31 | 32 | # Ignore specific errors that would require larger refactoring 33 | [mypy-nupunkt.trainers.base_trainer] 34 | # Allow list->Iterator type conversions which would require deeper refactoring 35 | disallow_any_generics = False 36 | # Allow None as default for bool parameters (for backward compatibility) 37 | no_implicit_optional = False 38 | 39 | [mypy-nupunkt.nupunkt] 40 | # Allow list->Iterator type conversions which would require deeper refactoring 41 | disallow_any_generics = 
False 42 | # Allow None as default for bool parameters (for backward compatibility) 43 | no_implicit_optional = False 44 | 45 | [mypy-scripts.utils.*] 46 | # Allow sys.path manipulation and imports for scripts 47 | ignore_errors = False 48 | # Some scripts use relative imports after path manipulation 49 | allow_redefinition = True -------------------------------------------------------------------------------- /nupunkt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | nupunkt is a Python library for sentence and paragraph boundary detection based on the Punkt algorithm. 3 | 4 | It learns to identify sentence boundaries in text, even when periods are used for 5 | abbreviations, ellipses, and other non-sentence-ending contexts. It also supports 6 | paragraph detection based on sentence boundaries and newlines. 7 | """ 8 | 9 | __version__ = "0.5.1" 10 | 11 | # Core classes 12 | from functools import lru_cache 13 | 14 | # Import for type annotations 15 | from typing import List, Tuple 16 | 17 | from nupunkt.core.language_vars import PunktLanguageVars 18 | from nupunkt.core.parameters import PunktParameters 19 | from nupunkt.core.tokens import PunktToken 20 | 21 | # Models 22 | from nupunkt.models import load_default_model 23 | from nupunkt.tokenizers.paragraph_tokenizer import PunktParagraphTokenizer 24 | 25 | # Tokenizers 26 | from nupunkt.tokenizers.sentence_tokenizer import PunktSentenceTokenizer 27 | 28 | # Trainers 29 | from nupunkt.trainers.base_trainer import PunktTrainer 30 | 31 | 32 | # Singleton pattern to load model only once 33 | @lru_cache(maxsize=1) 34 | def _get_default_model(): 35 | """Get the default model, loading it only once.""" 36 | return load_default_model() 37 | 38 | 39 | @lru_cache(maxsize=1) 40 | def _get_paragraph_tokenizer(): 41 | """Get the paragraph tokenizer with the default model, loading it only once.""" 42 | return PunktParagraphTokenizer(_get_default_model()) 43 | 44 | 45 | # Function for quick and easy sentence tokenization 46 | def sent_tokenize(text: str) -> List[str]: 47 | """ 48 | Tokenize text into sentences using the default pre-trained model. 49 | 50 | This is a convenience function for quick sentence tokenization 51 | without having to explicitly load a model. 52 | 53 | Args: 54 | text: The text to tokenize 55 | 56 | Returns: 57 | A list of sentences 58 | """ 59 | tokenizer = _get_default_model() 60 | return list(tokenizer.tokenize(text)) 61 | 62 | 63 | # Function for getting sentence spans 64 | def sent_spans(text: str) -> List[Tuple[int, int]]: 65 | """ 66 | Get sentence spans (start, end character positions) using the default pre-trained model. 67 | 68 | This is a convenience function for getting sentence spans without having 69 | to explicitly load a model. The spans are guaranteed to be contiguous, 70 | covering the entire input text without gaps. 71 | 72 | Args: 73 | text: The text to segment 74 | 75 | Returns: 76 | A list of sentence spans as (start_index, end_index) tuples 77 | """ 78 | from sentences import SentenceSegmenter 79 | 80 | segmenter = _get_default_model() 81 | return list(SentenceSegmenter.get_sentence_spans(segmenter, text)) 82 | 83 | 84 | # Function for getting sentence spans with text 85 | def sent_spans_with_text(text: str) -> List[Tuple[str, Tuple[int, int]]]: 86 | """ 87 | Get sentences with their spans using the default pre-trained model. 
88 | 89 | This is a convenience function for getting sentences with their character spans 90 | without having to explicitly load a model. The spans are guaranteed to be 91 | contiguous, covering the entire input text without gaps. 92 | 93 | Args: 94 | text: The text to segment 95 | 96 | Returns: 97 | A list of tuples containing (sentence, (start_index, end_index)) 98 | """ 99 | from sentences import SentenceSegmenter 100 | 101 | segmenter = _get_default_model() 102 | return list(SentenceSegmenter.get_sentence_spans_with_text(segmenter, text)) 103 | 104 | 105 | # Function for paragraph tokenization 106 | def para_tokenize(text: str) -> List[str]: 107 | """ 108 | Tokenize text into paragraphs using the default pre-trained model. 109 | 110 | Paragraph breaks are identified at sentence boundaries that are 111 | immediately followed by two or more newlines. 112 | 113 | Args: 114 | text: The text to tokenize 115 | 116 | Returns: 117 | A list of paragraphs 118 | """ 119 | paragraph_tokenizer = _get_paragraph_tokenizer() 120 | return list(paragraph_tokenizer.tokenize(text)) 121 | 122 | 123 | # Function for getting paragraph spans 124 | def para_spans(text: str) -> List[Tuple[int, int]]: 125 | """ 126 | Get paragraph spans (start, end character positions) using the default pre-trained model. 127 | 128 | This is a convenience function for getting paragraph spans without having 129 | to explicitly load a model. The spans are guaranteed to be contiguous, 130 | covering the entire input text without gaps. 131 | 132 | Args: 133 | text: The text to segment 134 | 135 | Returns: 136 | A list of paragraph spans as (start_index, end_index) tuples 137 | """ 138 | paragraph_tokenizer = _get_paragraph_tokenizer() 139 | return list(paragraph_tokenizer.span_tokenize(text)) 140 | 141 | 142 | # Function for getting paragraph spans with text 143 | def para_spans_with_text(text: str) -> List[Tuple[str, Tuple[int, int]]]: 144 | """ 145 | Get paragraphs with their spans using the default pre-trained model. 146 | 147 | This is a convenience function for getting paragraphs with their character spans 148 | without having to explicitly load a model. The spans are guaranteed to be 149 | contiguous, covering the entire input text without gaps. 150 | 151 | Args: 152 | text: The text to segment 153 | 154 | Returns: 155 | A list of tuples containing (paragraph, (start_index, end_index)) 156 | """ 157 | paragraph_tokenizer = _get_paragraph_tokenizer() 158 | return list(paragraph_tokenizer.tokenize_with_spans(text)) 159 | 160 | 161 | __all__ = [ 162 | "PunktParameters", 163 | "PunktLanguageVars", 164 | "PunktToken", 165 | "PunktTrainer", 166 | "PunktSentenceTokenizer", 167 | "PunktParagraphTokenizer", 168 | "load_default_model", 169 | "sent_tokenize", 170 | "sent_spans", 171 | "sent_spans_with_text", 172 | "para_tokenize", 173 | "para_spans", 174 | "para_spans_with_text", 175 | ] -------------------------------------------------------------------------------- /nupunkt/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core components for nupunkt.""" 2 | -------------------------------------------------------------------------------- /nupunkt/core/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base module for nupunkt. 3 | 4 | This module provides the base class for Punkt tokenizers and trainers. 
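# Minimal sketch of how the pieces in nupunkt.core fit together: PunktBase
# (defined below) wires language variables, parameters and a token class, and
# the trainer/tokenizer classes build on top of it.
from nupunkt.core.base import PunktBase
from nupunkt.core.language_vars import PunktLanguageVars
from nupunkt.core.parameters import PunktParameters

base = PunktBase(lang_vars=PunktLanguageVars(), params=PunktParameters())
# token_cls defaults to PunktToken; subclasses reuse _tokenize_words and
# _annotate_first_pass for the first annotation pass.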
5 | """ 6 | 7 | from functools import lru_cache 8 | from typing import Iterator, Optional, Type 9 | 10 | from nupunkt.core.constants import ABBREV_CACHE_SIZE 11 | from nupunkt.core.language_vars import PunktLanguageVars 12 | from nupunkt.core.parameters import PunktParameters 13 | from nupunkt.core.tokens import PunktToken, create_punkt_token 14 | 15 | 16 | @lru_cache(maxsize=ABBREV_CACHE_SIZE) 17 | def is_abbreviation(abbrev_set: frozenset, candidate: str) -> bool: 18 | """ 19 | Check if a candidate is a known abbreviation, using cached lookups. 20 | 21 | Args: 22 | abbrev_set: A frozenset of known abbreviations 23 | candidate: The candidate string to check 24 | 25 | Returns: 26 | True if the candidate is a known abbreviation, False otherwise 27 | """ 28 | # Check if the token itself is a known abbreviation 29 | if candidate in abbrev_set: 30 | return True 31 | 32 | # Check if the last part after a dash is a known abbreviation 33 | if "-" in candidate: 34 | dash_part = candidate.split("-")[-1] 35 | if dash_part in abbrev_set: 36 | return True 37 | 38 | # Special handling for period-separated abbreviations like U.S.C. 39 | # Check if the version without internal periods is in abbrev_types 40 | if "." in candidate: 41 | no_periods = candidate.replace(".", "") 42 | if no_periods in abbrev_set: 43 | return True 44 | 45 | return False 46 | 47 | 48 | class PunktBase: 49 | """ 50 | Base class for Punkt tokenizers and trainers. 51 | 52 | This class provides common functionality used by both the trainer and tokenizer, 53 | including tokenization and first-pass annotation of tokens. 54 | """ 55 | 56 | def __init__( 57 | self, 58 | lang_vars: Optional[PunktLanguageVars] = None, 59 | token_cls: Type[PunktToken] = PunktToken, 60 | params: Optional[PunktParameters] = None, 61 | ) -> None: 62 | """ 63 | Initialize the PunktBase instance. 64 | 65 | Args: 66 | lang_vars: Language-specific variables 67 | token_cls: The token class to use 68 | params: Punkt parameters 69 | """ 70 | self._lang_vars = lang_vars or PunktLanguageVars() 71 | self._Token = token_cls 72 | self._params = params or PunktParameters() 73 | 74 | def _tokenize_words(self, plaintext: str) -> Iterator[PunktToken]: 75 | """ 76 | Tokenize text into words, maintaining paragraph and line-start information. 
77 | 78 | Args: 79 | plaintext: The text to tokenize 80 | 81 | Yields: 82 | PunktToken instances for each token 83 | """ 84 | # Quick check for empty text 85 | if not plaintext: 86 | return 87 | 88 | parastart = False 89 | # Split by lines - this is more efficient than using splitlines() 90 | for line in plaintext.split("\n"): 91 | # Check if line has any content 92 | if line.strip(): 93 | tokens = self._lang_vars.word_tokenize(line) 94 | if tokens: 95 | # First token gets parastart and linestart flags 96 | # Use the factory function to benefit from caching 97 | if issubclass(self._Token, PunktToken): 98 | # Use our optimized factory function if the token class is PunktToken 99 | yield create_punkt_token(tokens[0], parastart=parastart, linestart=True) 100 | 101 | # Process remaining tokens in a batch when possible 102 | for tok in tokens[1:]: 103 | yield create_punkt_token(tok) 104 | else: 105 | # Fallback for custom token classes 106 | yield self._Token(tokens[0], parastart=parastart, linestart=True) 107 | 108 | # Process remaining tokens in a batch when possible 109 | for tok in tokens[1:]: 110 | yield self._Token(tok) 111 | parastart = False 112 | else: 113 | parastart = True 114 | 115 | def _annotate_first_pass(self, tokens: Iterator[PunktToken]) -> Iterator[PunktToken]: 116 | """ 117 | Perform first-pass annotation on tokens. 118 | 119 | This annotates tokens with sentence breaks, abbreviations, and ellipses. 120 | 121 | Args: 122 | tokens: The tokens to annotate 123 | 124 | Yields: 125 | Annotated tokens 126 | """ 127 | for token in tokens: 128 | self._first_pass_annotation(token) 129 | yield token 130 | 131 | def _first_pass_annotation(self, token: PunktToken) -> None: 132 | """ 133 | Annotate a token with sentence breaks, abbreviations, and ellipses. 134 | 135 | Args: 136 | token: The token to annotate 137 | """ 138 | if token.tok in self._lang_vars.sent_end_chars: 139 | token.sentbreak = True 140 | elif token.is_ellipsis: 141 | token.ellipsis = True 142 | # Don't mark as sentence break now - will be decided in second pass 143 | # based on what follows the ellipsis 144 | token.sentbreak = False 145 | elif token.period_final and not token.tok.endswith(".."): 146 | # If token is not a valid abbreviation candidate, mark it as a sentence break 147 | if not token.valid_abbrev_candidate: 148 | token.sentbreak = True 149 | else: 150 | # For valid candidates, check if they are known abbreviations 151 | candidate = token.tok[:-1].lower() 152 | 153 | # Use frozen set for faster lookups if available 154 | abbrev_set = ( 155 | getattr(self._params, "_frozen_abbrev_types", None) or self._params.abbrev_types 156 | ) 157 | 158 | # Convert to frozenset if it's not already (for caching) 159 | if not isinstance(abbrev_set, frozenset): 160 | abbrev_set = frozenset(abbrev_set) 161 | 162 | # Use the module-level cached function 163 | if is_abbreviation(abbrev_set, candidate): 164 | token.abbr = True 165 | else: 166 | token.sentbreak = True 167 | 168 | def _is_abbreviation(self, candidate: str) -> bool: 169 | """ 170 | Check if a candidate is a known abbreviation. 171 | 172 | This is a wrapper around the module-level cached function. 
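# Illustrative behaviour of the module-level cached lookup used here:
from nupunkt.core.base import is_abbreviation

known = frozenset({"dr", "usc", "etc"})
is_abbreviation(known, "dr")      # True  - direct membership
is_abbreviation(known, "ex-dr")   # True  - last dash-separated part matches
is_abbreviation(known, "u.s.c")   # True  - period-stripped form "usc" matches
is_abbreviation(known, "hello")   # False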
173 | 174 | Args: 175 | candidate: The candidate string to check 176 | 177 | Returns: 178 | True if the candidate is a known abbreviation, False otherwise 179 | """ 180 | # Use frozen set for faster lookups if available 181 | abbrev_set = ( 182 | getattr(self._params, "_frozen_abbrev_types", None) or self._params.abbrev_types 183 | ) 184 | 185 | # Convert to frozenset if it's not already (for caching) 186 | if not isinstance(abbrev_set, frozenset): 187 | abbrev_set = frozenset(abbrev_set) 188 | 189 | return is_abbreviation(abbrev_set, candidate) 190 | -------------------------------------------------------------------------------- /nupunkt/core/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants module for nupunkt. 3 | 4 | This module provides constants used in the Punkt algorithm, 5 | including orthographic context and cache configuration. 6 | """ 7 | 8 | from typing import Dict, Tuple 9 | 10 | # ------------------------------------------------------------------- 11 | # Orthographic Context Constants 12 | # ------------------------------------------------------------------- 13 | 14 | # Bit flags for orthographic contexts 15 | ORTHO_BEG_UC = 1 << 1 # Beginning of sentence, uppercase 16 | ORTHO_MID_UC = 1 << 2 # Middle of sentence, uppercase 17 | ORTHO_UNK_UC = 1 << 3 # Unknown position, uppercase 18 | ORTHO_BEG_LC = 1 << 4 # Beginning of sentence, lowercase 19 | ORTHO_MID_LC = 1 << 5 # Middle of sentence, lowercase 20 | ORTHO_UNK_LC = 1 << 6 # Unknown position, lowercase 21 | 22 | # Combined flags 23 | ORTHO_UC = ORTHO_BEG_UC | ORTHO_MID_UC | ORTHO_UNK_UC # Any uppercase 24 | ORTHO_LC = ORTHO_BEG_LC | ORTHO_MID_LC | ORTHO_UNK_LC # Any lowercase 25 | 26 | # Mapping from (position, case) to flag 27 | ORTHO_MAP: Dict[Tuple[str, str], int] = { 28 | ("initial", "upper"): ORTHO_BEG_UC, 29 | ("internal", "upper"): ORTHO_MID_UC, 30 | ("unknown", "upper"): ORTHO_UNK_UC, 31 | ("initial", "lower"): ORTHO_BEG_LC, 32 | ("internal", "lower"): ORTHO_MID_LC, 33 | ("unknown", "lower"): ORTHO_UNK_LC, 34 | } 35 | 36 | # ------------------------------------------------------------------- 37 | # Caching Constants 38 | # ------------------------------------------------------------------- 39 | 40 | # LRU cache sizes for various caching operations 41 | # These can be adjusted based on memory constraints and desired performance 42 | 43 | # Cache size for abbreviation checks - moderate size as number of abbreviations is usually limited 44 | ABBREV_CACHE_SIZE = 4096 # Power of 2 (2^12) 45 | 46 | # Cache size for token creation and property caching - larger as token variety is high 47 | TOKEN_CACHE_SIZE = 32768 # Power of 2 (2^15) - critical for performance 48 | 49 | # Cache size for orthographic heuristics - frequently used so needs to be large 50 | ORTHO_CACHE_SIZE = 8192 # Power of 2 (2^13) 51 | 52 | # Cache size for sentence starter checks - less variety than tokens 53 | SENT_STARTER_CACHE_SIZE = 4096 # Power of 2 (2^12) 54 | 55 | # Cache size for token type calculations - moderate variety 56 | TOKEN_TYPE_CACHE_SIZE = 16384 # Power of 2 (2^14) 57 | 58 | # Cache size for document-level tokenization results - benchmarks showed this is critical 59 | DOC_TOKENIZE_CACHE_SIZE = 8192 # Power of 2 (2^13) 60 | 61 | # Cache size for paragraph-level caching in tokenizer - moderate usage 62 | PARA_TOKENIZE_CACHE_SIZE = 8192 # Power of 2 (2^13) 63 | 64 | # Cache size for whitespace index lookups - pattern is less varied 65 | WHITESPACE_CACHE_SIZE = 2048 # Power 
of 2 (2^11) 66 | -------------------------------------------------------------------------------- /nupunkt/core/language_vars.py: -------------------------------------------------------------------------------- 1 | """ 2 | Language variables module for nupunkt. 3 | 4 | This module provides language-specific variables and settings 5 | for sentence boundary detection, which can be customized or 6 | extended for different languages. 7 | """ 8 | 9 | import re 10 | from typing import List, Optional 11 | 12 | 13 | class PunktLanguageVars: 14 | """ 15 | Contains language-specific variables for Punkt sentence boundary detection. 16 | 17 | This class encapsulates language-specific behavior, such as the 18 | characters that indicate sentence boundaries and the regular expressions 19 | used for various pattern matching tasks in the tokenization process. 20 | """ 21 | 22 | # Use frozenset for O(1) membership testing instead of tuple 23 | sent_end_chars: frozenset = frozenset((".", "?", "!")) 24 | internal_punctuation: str = ",:;" 25 | re_boundary_realignment: re.Pattern = re.compile(r'[\'"\)\]}]+?(?:\s+|(?=--)|$)', re.MULTILINE) 26 | _re_word_start: str = r"[^\(\"\`{\[:;&\#\*@\)}\]\-,]" 27 | _re_multi_char_punct: str = r"(?:\-{2,}|\.{2,}|(?:\.\s+){1,}\.|\u2026)" 28 | 29 | def __init__(self) -> None: 30 | """Initialize language variables with language-specific settings.""" 31 | self._re_period_context: Optional[re.Pattern] = None 32 | self._re_word_tokenizer: Optional[re.Pattern] = None 33 | 34 | @property 35 | def _re_sent_end_chars(self) -> str: 36 | """ 37 | Returns a regex pattern string for all sentence-ending characters. 38 | 39 | Returns: 40 | str: A pattern matching any sentence ending character 41 | """ 42 | return f"[{re.escape(''.join(self.sent_end_chars))}]" 43 | 44 | @property 45 | def _re_non_word_chars(self) -> str: 46 | """ 47 | Returns a regex pattern for characters that can never start a word. 48 | 49 | Returns: 50 | str: A pattern matching non-word-starting characters 51 | """ 52 | # Exclude characters that can never start a word 53 | nonword = "".join(set(self.sent_end_chars) - {"."}) 54 | return rf"(?:[)\";}}\]\*:@\'\({{[\s{re.escape(nonword)}])" 55 | 56 | @property 57 | def word_tokenize_pattern(self) -> re.Pattern: 58 | """ 59 | Returns a compiled regex pattern for tokenizing words. 60 | 61 | Returns: 62 | re.Pattern: The compiled regular expression for word tokenization 63 | """ 64 | if self._re_word_tokenizer is None: 65 | pattern = rf"""( 66 | {self._re_multi_char_punct} 67 | | 68 | (?={self._re_word_start})\S+? 69 | (?= 70 | \s| 71 | $| 72 | {self._re_non_word_chars}| 73 | {self._re_multi_char_punct}| 74 | ,(?=$|\s|{self._re_non_word_chars}|{self._re_multi_char_punct}) 75 | ) 76 | | 77 | \S 78 | )""" 79 | self._re_word_tokenizer = re.compile(pattern, re.UNICODE | re.VERBOSE) 80 | return self._re_word_tokenizer 81 | 82 | def word_tokenize(self, text: str) -> List[str]: 83 | """ 84 | Tokenize text into words using the word_tokenize_pattern. 85 | 86 | Args: 87 | text: The text to tokenize 88 | 89 | Returns: 90 | A list of word tokens 91 | """ 92 | return self.word_tokenize_pattern.findall(text) 93 | 94 | @property 95 | def period_context_pattern(self) -> re.Pattern: 96 | """ 97 | Returns a compiled regex pattern for finding periods in context. 
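# Illustrative sketch of the word tokenizer defined in this class (exact tokens
# follow word_tokenize_pattern above):
from nupunkt.core.language_vars import PunktLanguageVars

lv = PunktLanguageVars()
lv.word_tokenize("Mr. Smith arrived... finally!")
# e.g. ['Mr.', 'Smith', 'arrived', '...', 'finally', '!']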
98 | 99 | Returns: 100 | re.Pattern: The compiled regular expression for period contexts 101 | """ 102 | if self._re_period_context is None: 103 | pattern = rf""" 104 | {self._re_sent_end_chars} 105 | (?=(?P<after_tok> 106 | {self._re_non_word_chars}| 107 | \s+(?P<next_tok>\S+) 108 | )) 109 | """ 110 | self._re_period_context = re.compile(pattern, re.UNICODE | re.VERBOSE) 111 | return self._re_period_context 112 | -------------------------------------------------------------------------------- /nupunkt/core/parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | PunktParameters module - Contains the parameters for the Punkt algorithm. 3 | """ 4 | 5 | import re 6 | from collections import defaultdict 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional, Pattern, Set, Tuple, Union 10 | 11 | from nupunkt.utils.compression import ( 12 | load_compressed_json, 13 | save_binary_model, 14 | save_compressed_json, 15 | ) 16 | 17 | 18 | @dataclass 19 | class PunktParameters: 20 | """ 21 | Stores the parameters that Punkt uses for sentence boundary detection. 22 | 23 | This includes: 24 | - Abbreviation types 25 | - Collocations 26 | - Sentence starters 27 | - Orthographic context 28 | """ 29 | 30 | abbrev_types: Set[str] = field(default_factory=set) 31 | collocations: Set[Tuple[str, str]] = field(default_factory=set) 32 | sent_starters: Set[str] = field(default_factory=set) 33 | ortho_context: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) 34 | 35 | # Cached regex patterns for efficient lookups 36 | _abbrev_pattern: Optional[Pattern] = field(default=None, repr=False) 37 | _sent_starter_pattern: Optional[Pattern] = field(default=None, repr=False) 38 | 39 | def __post_init__(self) -> None: 40 | """Initialize any derived attributes after instance creation.""" 41 | # Patterns will be compiled on first use 42 | # Initialize frozen sets to empty frozensets 43 | self._frozen_abbrev_types = frozenset() 44 | self._frozen_collocations = frozenset() 45 | self._frozen_sent_starters = frozenset() 46 | 47 | def get_abbrev_pattern(self) -> Pattern: 48 | """ 49 | Get a compiled regex pattern for matching abbreviations. 50 | 51 | The pattern is compiled on first use and cached for subsequent calls. 52 | 53 | Returns: 54 | A compiled regex pattern that matches any abbreviation in abbrev_types 55 | """ 56 | if not self._abbrev_pattern or len(self._abbrev_pattern.pattern) == 0: 57 | if not self.abbrev_types: 58 | # If no abbreviations, create a pattern that will never match 59 | self._abbrev_pattern = re.compile(r"^$") 60 | else: 61 | # Escape abbreviations and sort by length (longest first) to ensure proper matching 62 | escaped_abbrevs = [re.escape(abbr) for abbr in self.abbrev_types] 63 | sorted_abbrevs = sorted(escaped_abbrevs, key=len, reverse=True) 64 | pattern = r"^(?:" + "|".join(sorted_abbrevs) + r")$" 65 | self._abbrev_pattern = re.compile(pattern, re.IGNORECASE) 66 | return self._abbrev_pattern 67 | 68 | def get_sent_starter_pattern(self) -> Pattern: 69 | """ 70 | Get a compiled regex pattern for matching sentence starters. 71 | 72 | The pattern is compiled on first use and cached for subsequent calls.
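# Sketch of typical use of the mutation helpers and cached patterns below:
from nupunkt.core.parameters import PunktParameters

params = PunktParameters()
params.update_abbrev_types({"dr", "etc", "vs"})
params.add_sent_starter("however")
params.freeze_sets()                              # optimize lookups for inference
bool(params.get_abbrev_pattern().match("Etc"))    # True - matching is case-insensitive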
73 | 74 | Returns: 75 | A compiled regex pattern that matches any sentence starter 76 | """ 77 | if not self._sent_starter_pattern or len(self._sent_starter_pattern.pattern) == 0: 78 | if not self.sent_starters: 79 | # If no sentence starters, create a pattern that will never match 80 | self._sent_starter_pattern = re.compile(r"^$") 81 | else: 82 | # Escape sentence starters and sort by length (longest first) 83 | escaped_starters = [re.escape(starter) for starter in self.sent_starters] 84 | sorted_starters = sorted(escaped_starters, key=len, reverse=True) 85 | pattern = r"^(?:" + "|".join(sorted_starters) + r")$" 86 | self._sent_starter_pattern = re.compile(pattern, re.IGNORECASE) 87 | return self._sent_starter_pattern 88 | 89 | def add_ortho_context(self, typ: str, flag: int) -> None: 90 | """ 91 | Add an orthographic context flag to a token type. 92 | 93 | Args: 94 | typ: The token type 95 | flag: The orthographic context flag 96 | """ 97 | self.ortho_context[typ] |= flag 98 | 99 | def add_abbreviation(self, abbrev: str) -> None: 100 | """ 101 | Add a single abbreviation and invalidate the cached pattern. 102 | 103 | Args: 104 | abbrev: The abbreviation to add 105 | """ 106 | self.abbrev_types.add(abbrev) 107 | self._abbrev_pattern = None 108 | 109 | def add_sent_starter(self, starter: str) -> None: 110 | """ 111 | Add a single sentence starter and invalidate the cached pattern. 112 | 113 | Args: 114 | starter: The sentence starter to add 115 | """ 116 | self.sent_starters.add(starter) 117 | self._sent_starter_pattern = None 118 | 119 | def invalidate_patterns(self) -> None: 120 | """Invalidate cached regex patterns when sets are modified.""" 121 | self._abbrev_pattern = None 122 | self._sent_starter_pattern = None 123 | 124 | def freeze_sets(self) -> None: 125 | """ 126 | Freeze the mutable sets to create immutable frozensets for faster lookups. 127 | 128 | Call this method after training is complete to optimize for inference speed. 129 | """ 130 | self._frozen_abbrev_types = frozenset(self.abbrev_types) 131 | self._frozen_collocations = frozenset(self.collocations) 132 | self._frozen_sent_starters = frozenset(self.sent_starters) 133 | 134 | def update_abbrev_types(self, abbrevs: Set[str]) -> None: 135 | """ 136 | Update abbreviation types and invalidate the cached pattern. 137 | 138 | Args: 139 | abbrevs: Set of abbreviations to add 140 | """ 141 | self.abbrev_types.update(abbrevs) 142 | self._abbrev_pattern = None 143 | 144 | def update_sent_starters(self, starters: Set[str]) -> None: 145 | """ 146 | Update sentence starters and invalidate the cached pattern. 
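# Sketch: recording orthographic evidence with the bit flags from nupunkt.core.constants.
from nupunkt.core.constants import ORTHO_BEG_UC, ORTHO_UC
from nupunkt.core.parameters import PunktParameters

params = PunktParameters()
params.add_ortho_context("smith", ORTHO_BEG_UC)   # seen capitalized at a sentence start
bool(params.ortho_context["smith"] & ORTHO_UC)    # True - uppercase evidence recorded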
147 | 148 | Args: 149 | starters: Set of sentence starters to add 150 | """ 151 | self.sent_starters.update(starters) 152 | self._sent_starter_pattern = None 153 | 154 | def to_json(self) -> Dict[str, Any]: 155 | """Convert parameters to a JSON-serializable dictionary.""" 156 | return { 157 | "abbrev_types": sorted(self.abbrev_types), 158 | "collocations": sorted([[c[0], c[1]] for c in self.collocations]), 159 | "sent_starters": sorted(self.sent_starters), 160 | "ortho_context": dict(self.ortho_context.items()), 161 | } 162 | 163 | @classmethod 164 | def from_json(cls, data: Dict[str, Any]) -> "PunktParameters": 165 | """Create a PunktParameters instance from a JSON dictionary.""" 166 | params = cls() 167 | params.abbrev_types = set(data.get("abbrev_types", [])) 168 | params.collocations = {tuple(c) for c in data.get("collocations", [])} 169 | params.sent_starters = set(data.get("sent_starters", [])) 170 | params.ortho_context = defaultdict(int) 171 | for k, v in data.get("ortho_context", {}).items(): 172 | params.ortho_context[k] = int(v) # Ensure value is int 173 | 174 | # Don't pre-compile patterns by default 175 | # Direct set lookup is faster based on benchmarks 176 | 177 | # Create frozen sets for faster lookups during inference 178 | params.freeze_sets() 179 | 180 | return params 181 | 182 | def save( 183 | self, 184 | file_path: Union[str, Path], 185 | format_type: str = "json_xz", 186 | compression_level: int = 1, 187 | compression_method: str = "zlib", 188 | ) -> None: 189 | """ 190 | Save parameters to a file using the specified format and compression. 191 | 192 | Args: 193 | file_path: The path to save the file to 194 | format_type: The format type to use ('json', 'json_xz', 'binary') 195 | compression_level: Compression level (0-9), lower is faster but less compressed 196 | compression_method: Compression method for binary format ('none', 'zlib', 'lzma', 'gzip') 197 | """ 198 | if format_type == "binary": 199 | save_binary_model( 200 | self.to_json(), 201 | file_path, 202 | compression_method=compression_method, 203 | level=compression_level, 204 | ) 205 | else: 206 | save_compressed_json( 207 | self.to_json(), 208 | file_path, 209 | level=compression_level, 210 | use_compression=(format_type == "json_xz"), 211 | ) 212 | 213 | @classmethod 214 | def load(cls, file_path: Union[str, Path]) -> "PunktParameters": 215 | """ 216 | Load parameters from a file in any supported format. 217 | 218 | This method automatically detects the file format based on extension 219 | and loads the parameters accordingly. 220 | 221 | Args: 222 | file_path: The path to the file 223 | 224 | Returns: 225 | A new PunktParameters instance 226 | """ 227 | # The load_compressed_json function will try to detect if it's a binary file 228 | data = load_compressed_json(file_path) 229 | 230 | # Handle binary format which is wrapped in a "parameters" key 231 | if "parameters" in data: 232 | return cls.from_json(data["parameters"]) 233 | 234 | return cls.from_json(data) 235 | -------------------------------------------------------------------------------- /nupunkt/core/tokens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Token module for nupunkt. 3 | 4 | This module provides the PunktToken class, which represents a token 5 | in the Punkt algorithm and calculates various derived properties. 
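# Illustrative examples of the derived properties computed in this module:
from nupunkt.core.tokens import PunktToken

PunktToken("J.").is_initial                        # True - single letter plus period
PunktToken("1,234.56").type                        # '##number##' - numeric tokens are normalized
PunktToken("Court", parastart=True).first_upper    # True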
6 | """ 7 | 8 | import re 9 | from functools import lru_cache 10 | from typing import Dict, Tuple 11 | 12 | # Compiled regex patterns for better performance 13 | _RE_NON_WORD_DOT = re.compile(r"[^\w.]") 14 | _RE_NUMBER = re.compile(r"^-?[\.,]?\d[\d,\.-]*\.?$") 15 | _RE_ELLIPSIS = re.compile(r"\.\.+$") 16 | _RE_SPACED_ELLIPSIS = re.compile(r"\.\s+\.\s+\.") 17 | _RE_INITIAL = re.compile(r"[^\W\d]\.") 18 | _RE_ALPHA = re.compile(r"[^\W\d]+") 19 | _RE_NON_PUNCT = re.compile(r"[^\W\d]") 20 | 21 | 22 | # LRU-cached functions for token classification to improve performance 23 | # Use a smaller cache for common tokens only 24 | @lru_cache(maxsize=500) 25 | def _check_is_ellipsis(tok: str) -> bool: 26 | """ 27 | Cached function to check if a token represents an ellipsis. 28 | 29 | Args: 30 | tok: The token to check 31 | 32 | Returns: 33 | True if the token is an ellipsis, False otherwise 34 | """ 35 | # Check for standard ellipsis (... or longer) 36 | if bool(_RE_ELLIPSIS.search(tok)): 37 | return True 38 | 39 | # Check for unicode ellipsis 40 | if tok == "\u2026" or tok.endswith("\u2026"): 41 | return True 42 | 43 | # Check for spaced ellipsis (. . ., . . ., etc.) 44 | return bool(_RE_SPACED_ELLIPSIS.search(tok)) 45 | 46 | 47 | @lru_cache(maxsize=500) 48 | def _check_is_initial(tok: str) -> bool: 49 | """ 50 | Cached function to check if a token is an initial. 51 | 52 | Args: 53 | tok: The token to check 54 | 55 | Returns: 56 | True if the token is an initial, False otherwise 57 | """ 58 | return bool(_RE_INITIAL.fullmatch(tok)) 59 | 60 | 61 | @lru_cache(maxsize=1000) 62 | def _check_is_alpha(tok: str) -> bool: 63 | """ 64 | Cached function to check if a token is alphabetic. 65 | 66 | Args: 67 | tok: The token to check 68 | 69 | Returns: 70 | True if the token is alphabetic, False otherwise 71 | """ 72 | return bool(_RE_ALPHA.fullmatch(tok)) 73 | 74 | 75 | @lru_cache(maxsize=1000) 76 | def _check_is_non_punct(typ: str) -> bool: 77 | """ 78 | Cached function to check if a token type contains non-punctuation. 79 | 80 | Args: 81 | typ: The token type to check 82 | 83 | Returns: 84 | True if the token type contains non-punctuation, False otherwise 85 | """ 86 | return bool(_RE_NON_PUNCT.search(typ)) 87 | 88 | 89 | @lru_cache(maxsize=2000) # Increased cache size for token types 90 | def _get_token_type(tok: str) -> str: 91 | """ 92 | Get the normalized type of a token (cached for better performance). 93 | 94 | Args: 95 | tok: The token string 96 | 97 | Returns: 98 | The normalized type (##number## for numbers, lowercase form for others) 99 | """ 100 | # Normalize numbers 101 | if _RE_NUMBER.match(tok): 102 | return "##number##" 103 | return tok.lower() 104 | 105 | 106 | @lru_cache(maxsize=1000) 107 | def _get_type_no_period(type_str: str) -> str: 108 | """Get the token type without a trailing period (cached).""" 109 | return type_str[:-1] if type_str.endswith(".") and len(type_str) > 1 else type_str 110 | 111 | 112 | # Module-level cache for PunktToken instances 113 | _token_instance_cache: Dict[Tuple[str, bool, bool], "PunktToken"] = {} 114 | _TOKEN_CACHE_SIZE = 2000 # Increased from original 1000 115 | 116 | 117 | def create_punkt_token(tok: str, parastart: bool = False, linestart: bool = False) -> "PunktToken": 118 | """ 119 | Factory function to create PunktToken instances with caching. 
120 | 121 | Args: 122 | tok: Token text 123 | parastart: Whether the token starts a paragraph 124 | linestart: Whether the token starts a line 125 | 126 | Returns: 127 | A new or cached PunktToken instance 128 | """ 129 | # Only cache smaller tokens (most common case) 130 | if len(tok) < 15: 131 | cache_key = (tok, parastart, linestart) 132 | token = _token_instance_cache.get(cache_key) 133 | if token is not None: 134 | return token 135 | 136 | token = PunktToken(tok, parastart, linestart) 137 | 138 | # Add to cache if not full 139 | if len(_token_instance_cache) < _TOKEN_CACHE_SIZE: 140 | _token_instance_cache[cache_key] = token 141 | return token 142 | 143 | # For longer tokens, just create a new instance 144 | return PunktToken(tok, parastart, linestart) 145 | 146 | 147 | class PunktToken: 148 | """ 149 | Represents a token in the Punkt algorithm. 150 | 151 | This class contains the token string and various properties and flags that 152 | indicate its role in sentence boundary detection. 153 | 154 | Uses __slots__ for memory efficiency, especially for large documents 155 | where millions of token instances are created. 156 | """ 157 | 158 | __slots__ = ( 159 | "tok", 160 | "parastart", 161 | "linestart", 162 | "sentbreak", 163 | "abbr", 164 | "ellipsis", 165 | "period_final", 166 | "type", 167 | "valid_abbrev_candidate", 168 | "_first_upper", 169 | "_first_lower", 170 | "_type_no_period", 171 | "_type_no_sentperiod", 172 | "_is_ellipsis", 173 | "_is_number", 174 | "_is_initial", 175 | "_is_alpha", 176 | "_is_non_punct", 177 | ) 178 | 179 | # Define allowed characters for fast punctuation check (alphanumeric + period) 180 | _ALLOWED_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.") 181 | 182 | def __init__(self, tok: str, parastart: bool = False, linestart: bool = False) -> None: 183 | """ 184 | Initialize a new PunktToken instance. 
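# Illustrative: how the constructor classifies abbreviation candidates:
from nupunkt.core.tokens import PunktToken

PunktToken("etc.").valid_abbrev_candidate    # True  - letters with a final period
PunktToken("etc").valid_abbrev_candidate     # False - no final period
PunktToken("2021.").valid_abbrev_candidate   # False - normalized to ##number##
PunktToken("(etc.").valid_abbrev_candidate   # False - contains a disallowed character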
185 | 186 | Args: 187 | tok: The token string 188 | parastart: Whether this token starts a paragraph 189 | linestart: Whether this token starts a line 190 | """ 191 | # Initialize base attributes 192 | self.tok = tok 193 | self.parastart = parastart 194 | self.linestart = linestart 195 | self.sentbreak = False 196 | self.abbr = False 197 | self.ellipsis = False 198 | 199 | # Initialize computed attributes 200 | self.period_final = tok.endswith(".") 201 | self.type = _get_token_type(tok) 202 | 203 | # Pre-compute frequently accessed properties 204 | tok_len = len(tok) 205 | self._first_upper = tok_len > 0 and tok[0].isupper() 206 | self._first_lower = tok_len > 0 and tok[0].islower() 207 | 208 | # Initialize lazily computed properties (will be set on first access) 209 | self._type_no_period = None 210 | self._type_no_sentperiod = None 211 | self._is_ellipsis = None 212 | self._is_number = None 213 | self._is_initial = None 214 | self._is_alpha = None 215 | self._is_non_punct = None 216 | 217 | # Fast check for invalid characters (non-alphanumeric and non-period) 218 | has_invalid_char = False 219 | for c in tok: 220 | if c not in self._ALLOWED_CHARS: 221 | has_invalid_char = True 222 | break 223 | 224 | if self.period_final and not has_invalid_char: 225 | # For tokens with internal periods (like U.S.C), get non-period chars 226 | # Use more efficient counting method 227 | alpha_count = 0 228 | digit_count = 0 229 | for c in tok: 230 | if c != ".": 231 | if c.isalpha(): 232 | alpha_count += 1 233 | elif c.isdigit(): 234 | digit_count += 1 235 | 236 | self.valid_abbrev_candidate = ( 237 | self.type != "##number##" and alpha_count >= digit_count and alpha_count > 0 238 | ) 239 | else: 240 | self.valid_abbrev_candidate = False 241 | 242 | # If token has a period but isn't valid candidate, reset abbr flag 243 | if self.period_final and not self.valid_abbrev_candidate: 244 | self.abbr = False 245 | 246 | @property 247 | def type_no_period(self) -> str: 248 | """Get the token type without a trailing period.""" 249 | if self._type_no_period is None: 250 | self._type_no_period = _get_type_no_period(self.type) 251 | return self._type_no_period 252 | 253 | @property 254 | def type_no_sentperiod(self) -> str: 255 | """Get the token type without a sentence-final period.""" 256 | if self._type_no_sentperiod is None: 257 | self._type_no_sentperiod = self.type_no_period if self.sentbreak else self.type 258 | return self._type_no_sentperiod 259 | 260 | @property 261 | def first_upper(self) -> bool: 262 | """Check if the first character of the token is uppercase.""" 263 | return self._first_upper 264 | 265 | @property 266 | def first_lower(self) -> bool: 267 | """Check if the first character of the token is lowercase.""" 268 | return self._first_lower 269 | 270 | @property 271 | def first_case(self) -> str: 272 | """Get the case of the first character of the token.""" 273 | if self.first_lower: 274 | return "lower" 275 | if self.first_upper: 276 | return "upper" 277 | return "none" 278 | 279 | @property 280 | def is_ellipsis(self) -> bool: 281 | """ 282 | Check if the token is an ellipsis (any of the following patterns): 283 | 1. Multiple consecutive periods (..., ......) 284 | 2. Unicode ellipsis character (…) 285 | 3. Periods separated by spaces (. . ., . . .) 
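# Illustrative:
PunktToken("...").is_ellipsis      # True - consecutive periods
PunktToken("\u2026").is_ellipsis   # True - unicode ellipsis
PunktToken(". . .").is_ellipsis    # True - spaced periods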
286 | """ 287 | if self._is_ellipsis is None: 288 | self._is_ellipsis = _check_is_ellipsis(self.tok) 289 | return self._is_ellipsis 290 | 291 | @property 292 | def is_number(self) -> bool: 293 | """Check if the token is a number.""" 294 | if self._is_number is None: 295 | self._is_number = self.type.startswith("##number##") 296 | return self._is_number 297 | 298 | @property 299 | def is_initial(self) -> bool: 300 | """Check if the token is an initial (single letter followed by a period).""" 301 | if self._is_initial is None: 302 | self._is_initial = _check_is_initial(self.tok) 303 | return self._is_initial 304 | 305 | @property 306 | def is_alpha(self) -> bool: 307 | """Check if the token is alphabetic (contains only letters).""" 308 | if self._is_alpha is None: 309 | self._is_alpha = _check_is_alpha(self.tok) 310 | return self._is_alpha 311 | 312 | @property 313 | def is_non_punct(self) -> bool: 314 | """Check if the token contains non-punctuation characters.""" 315 | if self._is_non_punct is None: 316 | self._is_non_punct = _check_is_non_punct(self.type) 317 | return self._is_non_punct 318 | 319 | def __str__(self) -> str: 320 | """Get a string representation of the token with annotation flags.""" 321 | s = self.tok 322 | if self.abbr: 323 | s += "" 324 | if self.ellipsis: 325 | s += "" 326 | if self.sentbreak: 327 | s += "" 328 | return s 329 | 330 | def __repr__(self) -> str: 331 | """Get a detailed string representation of the token.""" 332 | return ( 333 | f"PunktToken(tok='{self.tok}', parastart={self.parastart}, " 334 | f"linestart={self.linestart}, sentbreak={self.sentbreak}, " 335 | f"abbr={self.abbr}, ellipsis={self.ellipsis})" 336 | ) 337 | -------------------------------------------------------------------------------- /nupunkt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model package for nupunkt. 3 | 4 | This module provides functionality for loading and optimizing the default pre-trained model. 5 | """ 6 | 7 | from pathlib import Path 8 | from typing import Dict, Optional, Union 9 | 10 | from nupunkt.tokenizers.sentence_tokenizer import PunktSentenceTokenizer 11 | from nupunkt.utils.compression import ( 12 | compare_formats, 13 | load_compressed_json, 14 | save_binary_model, 15 | save_compressed_json, 16 | ) 17 | 18 | 19 | def get_default_model_path() -> Path: 20 | """ 21 | Get the path to the default pre-trained model. 22 | 23 | The function searches for models in priority order: 24 | 1. Binary format (.bin) 25 | 2. Compressed JSON (.json.xz) 26 | 3. Uncompressed JSON (.json) 27 | 28 | Returns: 29 | Path: The path to the default model file 30 | """ 31 | base_dir = Path(__file__).parent 32 | 33 | # Check for binary model first (most efficient format) 34 | binary_path = base_dir / "default_model.bin" 35 | if binary_path.exists(): 36 | return binary_path 37 | 38 | # Check for compressed model next 39 | compressed_path = base_dir / "default_model.json.xz" 40 | if compressed_path.exists(): 41 | return compressed_path 42 | 43 | # Fall back to uncompressed model 44 | return base_dir / "default_model.json" 45 | 46 | 47 | def load_default_model() -> PunktSentenceTokenizer: 48 | """ 49 | Load the default pre-trained model. 
50 | 51 | Returns: 52 | PunktSentenceTokenizer: A tokenizer initialized with the default model 53 | """ 54 | model_path = get_default_model_path() 55 | return PunktSentenceTokenizer.load(model_path) 56 | 57 | 58 | def optimize_default_model( 59 | output_path: Optional[Union[str, Path]] = None, 60 | format_type: str = "binary", 61 | compression_method: str = "lzma", 62 | compression_level: int = 6, 63 | ) -> Path: 64 | """ 65 | Optimize the default model using the specified format and compression. 66 | 67 | Args: 68 | output_path: Optional path to save the optimized model. If None, 69 | saves to the default location based on format_type. 70 | format_type: Format to use ("binary", "json_xz", or "json") 71 | compression_method: For binary format, the compression method to use 72 | ("none", "zlib", "lzma", "gzip") 73 | compression_level: Compression level (0-9). Higher means better 74 | compression but slower operation. 75 | 76 | Returns: 77 | Path: The path to the optimized model file 78 | """ 79 | # Get the current model data (from whatever format it's currently in) 80 | current_model_path = get_default_model_path() 81 | data = load_compressed_json(current_model_path) 82 | 83 | # Determine output path and extension based on format 84 | base_dir = Path(__file__).parent 85 | if output_path is None: 86 | if format_type == "binary": 87 | output_path = base_dir / "default_model.bin" 88 | elif format_type == "json_xz": 89 | output_path = base_dir / "default_model.json.xz" 90 | else: 91 | output_path = base_dir / "default_model.json" 92 | else: 93 | output_path = Path(output_path) 94 | 95 | # Save in the requested format 96 | if format_type == "binary": 97 | save_binary_model( 98 | data, output_path, compression_method=compression_method, level=compression_level 99 | ) 100 | else: 101 | save_compressed_json( 102 | data, output_path, level=compression_level, use_compression=(format_type == "json_xz") 103 | ) 104 | 105 | return output_path 106 | 107 | 108 | def compare_model_formats(output_dir: Optional[Union[str, Path]] = None) -> Dict[str, int]: 109 | """ 110 | Compare different storage formats for the default model and output their file sizes. 111 | 112 | This function creates multiple versions of the default model in different formats 113 | and compression settings and returns their file sizes. 
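# Sketch: re-export the bundled model and compare storage formats
# ("/tmp/nupunkt_formats" is a hypothetical output directory):
from nupunkt.models import compare_model_formats, optimize_default_model

new_path = optimize_default_model(format_type="json_xz", compression_level=6)
sizes = compare_model_formats("/tmp/nupunkt_formats")   # {format_name: size_in_bytes}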
114 | 115 | Args: 116 | output_dir: Directory to save test files (if None, uses temp directory) 117 | 118 | Returns: 119 | Dictionary mapping format names to file sizes in bytes 120 | """ 121 | # Load the current model data 122 | current_model_path = get_default_model_path() 123 | data = load_compressed_json(current_model_path) 124 | 125 | # Compare formats and return results 126 | return compare_formats(data, output_dir) 127 | -------------------------------------------------------------------------------- /nupunkt/models/default_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alea-institute/nupunkt/29d056aba0f6c9e0f43ee1e36d2638260027af0c/nupunkt/models/default_model.bin -------------------------------------------------------------------------------- /nupunkt/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alea-institute/nupunkt/29d056aba0f6c9e0f43ee1e36d2638260027af0c/nupunkt/py.typed -------------------------------------------------------------------------------- /nupunkt/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | """Tokenizer module for nupunkt.""" 2 | 3 | from nupunkt.tokenizers.paragraph_tokenizer import PunktParagraphTokenizer 4 | from nupunkt.tokenizers.sentence_tokenizer import PunktSentenceTokenizer 5 | 6 | __all__ = ["PunktSentenceTokenizer", "PunktParagraphTokenizer"] -------------------------------------------------------------------------------- /nupunkt/tokenizers/paragraph_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paragraph tokenizer module for nupunkt. 3 | 4 | This module provides tokenizer classes for paragraph boundary detection. 5 | """ 6 | 7 | import re 8 | from pathlib import Path 9 | from typing import List, Optional, Tuple, Type, Union 10 | 11 | from nupunkt.core.language_vars import PunktLanguageVars 12 | from nupunkt.core.tokens import PunktToken 13 | from nupunkt.tokenizers.sentence_tokenizer import PunktSentenceTokenizer 14 | 15 | # Precompiled regex pattern for two or more consecutive newlines 16 | # This will efficiently detect paragraph breaks 17 | PARAGRAPH_BREAK_PATTERN = re.compile(r"\n\s*\n+") 18 | 19 | 20 | class PunktParagraphTokenizer: 21 | """ 22 | Paragraph tokenizer using sentence boundaries and newlines. 23 | 24 | This tokenizer identifies paragraph breaks ONLY at sentence boundaries that are 25 | immediately followed by two or more newlines. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | sentence_tokenizer: Optional[PunktSentenceTokenizer] = None, 31 | lang_vars: Optional[PunktLanguageVars] = None, 32 | token_cls: Type[PunktToken] = PunktToken, 33 | ) -> None: 34 | """ 35 | Initialize the paragraph tokenizer. 
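# Sketch: paragraph tokenization backed by the cached default sentence model.
from nupunkt.tokenizers.paragraph_tokenizer import PunktParagraphTokenizer

pt = PunktParagraphTokenizer()
text = "One sentence. Another sentence.\n\nA new paragraph."
pt.tokenize(text)        # two paragraphs, split at the boundary before the blank line
pt.span_tokenize(text)   # contiguous (start, end) spans covering the whole text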
36 | 37 | Args: 38 | sentence_tokenizer: The sentence tokenizer to use (if None, a default model will be loaded) 39 | lang_vars: Language-specific variables 40 | token_cls: The token class to use 41 | """ 42 | self._lang_vars = lang_vars or PunktLanguageVars() 43 | self._Token = token_cls 44 | 45 | # Initialize the sentence tokenizer 46 | if sentence_tokenizer is None: 47 | # Use the singleton pattern to avoid reloading the model 48 | self._sentence_tokenizer = self._get_default_model() 49 | else: 50 | self._sentence_tokenizer = sentence_tokenizer 51 | 52 | # Class-level variable for caching the default model 53 | _default_model = None 54 | 55 | @staticmethod 56 | def _get_default_model(): 57 | """Get the default model, with caching using a class-level variable.""" 58 | # Use class-level caching for the default model 59 | if PunktParagraphTokenizer._default_model is None: 60 | from nupunkt.models import load_default_model 61 | 62 | PunktParagraphTokenizer._default_model = load_default_model() 63 | return PunktParagraphTokenizer._default_model 64 | 65 | def tokenize(self, text: str) -> List[str]: 66 | """ 67 | Tokenize text into paragraphs. 68 | 69 | Args: 70 | text: The text to tokenize 71 | 72 | Returns: 73 | A list of paragraphs 74 | """ 75 | return [paragraph for paragraph, _ in self.tokenize_with_spans(text)] 76 | 77 | def tokenize_with_spans(self, text: str) -> List[Tuple[str, Tuple[int, int]]]: 78 | """ 79 | Tokenize text into paragraphs with their character spans. 80 | 81 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 82 | and end_idx is exclusive (following Python's slicing convention). 83 | The spans are guaranteed to cover the entire input text without gaps. 84 | 85 | Args: 86 | text: The text to tokenize 87 | 88 | Returns: 89 | List of tuples containing (paragraph, (start_index, end_index)) 90 | """ 91 | # Quick return for empty text 92 | if not text: 93 | return [] 94 | 95 | # Get all sentence boundary positions 96 | spans = list(self._sentence_tokenizer.span_tokenize(text)) 97 | boundary_positions = [span[1] for span in spans] 98 | 99 | # If no boundaries, treat the whole text as one paragraph 100 | if not boundary_positions: 101 | return [(text, (0, len(text)))] 102 | 103 | # Find paragraph boundaries (sentence boundaries followed by 2+ newlines) 104 | paragraph_boundaries = [] 105 | 106 | for pos in boundary_positions: 107 | # Look for 2+ newlines right after this boundary 108 | window_end = min(pos + 10, len(text)) # 10 char window is enough for newlines 109 | 110 | if pos < len(text): 111 | # Get the text slice to check 112 | window = text[pos:window_end] 113 | 114 | # Search for 2+ newlines in the window using the precompiled pattern 115 | match = PARAGRAPH_BREAK_PATTERN.search(window) 116 | 117 | # Only consider it a match if the newlines appear near the start of the window 118 | if match and match.start() <= 3: # Allow for a few whitespace chars after sentence 119 | paragraph_boundaries.append(pos) 120 | 121 | # Always include the end of text as a paragraph boundary 122 | if not paragraph_boundaries or paragraph_boundaries[-1] != len(text): 123 | paragraph_boundaries.append(len(text)) 124 | 125 | # Create paragraph spans from boundaries 126 | result = [] 127 | start_idx = 0 128 | 129 | for end_idx in paragraph_boundaries: 130 | # Get the paragraph text (without stripping to maintain all whitespace) 131 | paragraph_text = text[start_idx:end_idx] 132 | 133 | # Include all paragraphs to ensure contiguity 134 | result.append((paragraph_text, 
(start_idx, end_idx))) 135 | start_idx = end_idx 136 | 137 | return result 138 | 139 | def span_tokenize(self, text: str) -> List[Tuple[int, int]]: 140 | """ 141 | Tokenize text into paragraph spans. 142 | 143 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 144 | and end_idx is exclusive (following Python's slicing convention). 145 | The spans are guaranteed to cover the entire input text without gaps. 146 | 147 | Args: 148 | text: The text to tokenize 149 | 150 | Returns: 151 | List of paragraph spans (start_index, end_index) 152 | """ 153 | return [span for _, span in self.tokenize_with_spans(text)] 154 | 155 | def save( 156 | self, file_path: Union[str, Path], compress: bool = True, compression_level: int = 1 157 | ) -> None: 158 | """ 159 | Save the tokenizer to a file. 160 | 161 | This saves the underlying sentence tokenizer. 162 | 163 | Args: 164 | file_path: The path to save the file to 165 | compress: Whether to compress the file using LZMA (default: True) 166 | compression_level: LZMA compression level (0-9), lower is faster but less compressed 167 | """ 168 | self._sentence_tokenizer.save(file_path, compress, compression_level) 169 | 170 | @classmethod 171 | def load( 172 | cls, 173 | file_path: Union[str, Path], 174 | lang_vars: Optional[PunktLanguageVars] = None, 175 | token_cls: Optional[Type[PunktToken]] = None, 176 | ) -> "PunktParagraphTokenizer": 177 | """ 178 | Load a PunktParagraphTokenizer from a file. 179 | 180 | Args: 181 | file_path: The path to load the file from 182 | lang_vars: Optional language variables 183 | token_cls: Optional token class 184 | 185 | Returns: 186 | A new PunktParagraphTokenizer instance 187 | """ 188 | sentence_tokenizer = PunktSentenceTokenizer.load(file_path, lang_vars, token_cls) 189 | return cls(sentence_tokenizer, lang_vars, token_cls or PunktToken) -------------------------------------------------------------------------------- /nupunkt/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Trainer module for nupunkt.""" 2 | 3 | from nupunkt.trainers.base_trainer import PunktTrainer 4 | 5 | __all__ = ["PunktTrainer"] 6 | -------------------------------------------------------------------------------- /nupunkt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility modules for nupunkt.""" 2 | 3 | from nupunkt.utils.compression import load_compressed_json, save_compressed_json 4 | from nupunkt.utils.iteration import pair_iter 5 | from nupunkt.utils.statistics import collocation_log_likelihood, dunning_log_likelihood 6 | 7 | __all__ = [ 8 | "pair_iter", 9 | "dunning_log_likelihood", 10 | "collocation_log_likelihood", 11 | "save_compressed_json", 12 | "load_compressed_json", 13 | ] 14 | -------------------------------------------------------------------------------- /nupunkt/utils/iteration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Iteration utilities for nupunkt. 3 | 4 | This module provides utility functions for iterating through sequences 5 | with specialized behaviors needed for the Punkt algorithm. 6 | """ 7 | 8 | from typing import Any, Iterator, Optional, Sequence, Tuple, TypeVar 9 | 10 | T = TypeVar("T") 11 | 12 | 13 | def pair_iter(iterable: Iterator[Any]) -> Iterator[Tuple[Any, Optional[Any]]]: 14 | """ 15 | Iterate through pairs of items from an iterable, where the second item 16 | can be None for the last item. 
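# Illustrative:
from nupunkt.utils.iteration import pair_iter

list(pair_iter(iter(["a", "b", "c"])))   # [('a', 'b'), ('b', 'c'), ('c', None)]
# pair_iter_fast (below) yields the same pairs for list/tuple inputs without
# the iterator overhead.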
17 | 18 | Args: 19 | iterable: The input iterator 20 | 21 | Yields: 22 | Pairs of (current_item, next_item) where next_item is None for the last item 23 | """ 24 | it = iter(iterable) 25 | prev = next(it, None) 26 | if prev is None: 27 | return 28 | for current in it: 29 | yield prev, current 30 | prev = current 31 | yield prev, None 32 | 33 | 34 | def pair_iter_fast(items: Sequence[T]) -> Iterator[Tuple[T, Optional[T]]]: 35 | """ 36 | Fast implementation of pair iteration for sequences (lists, tuples). 37 | This avoids the iterator overhead for known sequence types. 38 | 39 | Args: 40 | items: A sequence (list or tuple) to iterate through in pairs 41 | 42 | Yields: 43 | Pairs of (current_item, next_item) where next_item is None for the last item 44 | """ 45 | length = len(items) 46 | if length == 0: 47 | return 48 | 49 | # Handle all but the last item 50 | for i in range(length - 1): 51 | yield items[i], items[i + 1] 52 | 53 | # Handle the last item (with None as the next item) 54 | yield items[length - 1], None 55 | -------------------------------------------------------------------------------- /nupunkt/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Statistical utilities for nupunkt. 3 | 4 | This module provides statistical functions used in the Punkt algorithm 5 | for calculating log-likelihood and related measurements. 6 | """ 7 | 8 | import math 9 | 10 | 11 | def dunning_log_likelihood(count_a: int, count_b: int, count_ab: int, N: int) -> float: 12 | """ 13 | Modified Dunning log-likelihood calculation that gives higher weight to 14 | potential abbreviations. This makes the model more likely to detect abbreviations, 15 | especially in larger datasets where evidence may be diluted. 16 | 17 | Args: 18 | count_a: Count of event A (e.g., token appears) 19 | count_b: Count of event B (e.g., token appears with period) 20 | count_ab: Count of events A and B together 21 | N: Total count of all events 22 | 23 | Returns: 24 | The log likelihood score (higher means more significant) 25 | """ 26 | p1 = count_b / N 27 | p2 = 0.99 28 | null_hypo = count_ab * math.log(p1 + 1e-8) + (count_a - count_ab) * math.log(1.0 - p1 + 1e-8) 29 | alt_hypo = count_ab * math.log(p2) + (count_a - count_ab) * math.log(1.0 - p2) 30 | 31 | # Basic log likelihood calculation 32 | ll = -2.0 * (null_hypo - alt_hypo) 33 | 34 | # Boosting factor for short tokens (likely abbreviations) 35 | # This makes the algorithm more sensitive to abbreviation detection 36 | return ll * 1.5 37 | 38 | 39 | def collocation_log_likelihood(count_a: int, count_b: int, count_ab: int, N: int) -> float: 40 | """ 41 | Calculate the log-likelihood ratio for collocations. 
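# Illustrative call with made-up counts: how strongly a pair such as "et"/"al"
# co-occurs relative to chance, given their individual frequencies over N tokens.
from nupunkt.utils.statistics import collocation_log_likelihood

score = collocation_log_likelihood(count_a=120, count_b=95, count_ab=90, N=100_000)
# A large positive score suggests the pair behaves like a genuine collocation.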
42 | 43 | Args: 44 | count_a: Count of the first token 45 | count_b: Count of the second token 46 | count_ab: Count of the collocation (first and second token together) 47 | N: Total number of tokens 48 | 49 | Returns: 50 | The log likelihood score for the collocation 51 | """ 52 | p = count_b / N 53 | p1 = count_ab / count_a if count_a else 0 54 | try: 55 | p2 = (count_b - count_ab) / (N - count_a) if (N - count_a) else 0 56 | except ZeroDivisionError: 57 | p2 = 1 58 | try: 59 | summand1 = count_ab * math.log(p) + (count_a - count_ab) * math.log(1.0 - p) 60 | except ValueError: 61 | summand1 = 0 62 | try: 63 | summand2 = (count_b - count_ab) * math.log(p) + ( 64 | N - count_a - count_b + count_ab 65 | ) * math.log(1.0 - p) 66 | except ValueError: 67 | summand2 = 0 68 | summand3 = ( 69 | 0 70 | if count_a == count_ab or p1 <= 0 or p1 >= 1 71 | else count_ab * math.log(p1) + (count_a - count_ab) * math.log(1.0 - p1) 72 | ) 73 | summand4 = ( 74 | 0 75 | if count_b == count_ab or p2 <= 0 or p2 >= 1 76 | else (count_b - count_ab) * math.log(p2) 77 | + (N - count_a - count_b + count_ab) * math.log(1.0 - p2) 78 | ) 79 | return -2.0 * (summand1 + summand2 - summand3 - summand4) 80 | -------------------------------------------------------------------------------- /paragraphs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paragraph segmentation functionality. 3 | """ 4 | 5 | import re 6 | from typing import List, Tuple, Optional, TYPE_CHECKING 7 | 8 | from charboundary.segmenters.spans import SpanHandler 9 | 10 | if TYPE_CHECKING: 11 | from charboundary.segmenters.base import TextSegmenter 12 | 13 | # Precompiled regex pattern for two or more consecutive newlines 14 | # This will efficiently detect paragraph breaks 15 | PARAGRAPH_BREAK_PATTERN = re.compile(r"\n\s*\n+") 16 | 17 | 18 | class ParagraphSegmenter: 19 | """ 20 | Handles segmenting text into paragraphs. 21 | """ 22 | 23 | @classmethod 24 | def segment_to_paragraphs( 25 | cls, 26 | segmenter: "TextSegmenter", 27 | text: str, 28 | streaming: bool = False, # pylint: disable=unused-argument 29 | threshold: Optional[float] = None, 30 | ) -> List[str]: 31 | """ 32 | Segment text into a list of paragraphs. 33 | 34 | This method identifies paragraph breaks ONLY at sentence boundaries that are 35 | immediately followed by two or more newlines. 36 | 37 | Args: 38 | segmenter: The TextSegmenter to use 39 | text (str): Text to segment 40 | streaming (bool, optional): Ignored. Included for API compatibility only. 41 | Defaults to False. 42 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 43 | Values below 0.5 favor recall (fewer false negatives), 44 | values above 0.5 favor precision (fewer false positives). 45 | If None, use the model's default threshold. 46 | Defaults to None. 
47 | 48 | Returns: 49 | List[str]: List of paragraphs 50 | """ 51 | # Quick return for empty text 52 | if not text: 53 | return [] 54 | 55 | # We need to handle the case where the sentence segmentation 56 | # includes the newlines in the next sentence 57 | paragraph_spans = cls.get_paragraph_spans_with_text( 58 | segmenter, text, streaming, threshold 59 | ) 60 | return [para_text for para_text, _ in paragraph_spans] 61 | 62 | @classmethod 63 | def get_paragraph_spans( 64 | cls, 65 | segmenter: "TextSegmenter", 66 | text: str, 67 | streaming: bool = False, # pylint: disable=unused-argument 68 | threshold: Optional[float] = None, 69 | ) -> List[Tuple[int, int]]: 70 | """ 71 | Get the character spans for each paragraph in the text. 72 | 73 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 74 | and end_idx is exclusive (following Python's slicing convention). 75 | The spans are guaranteed to cover the entire input text without gaps. 76 | 77 | Args: 78 | segmenter: The TextSegmenter to use 79 | text (str): Text to segment 80 | streaming (bool, optional): Ignored. Included for API compatibility only. 81 | Defaults to False. 82 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 83 | Values below 0.5 favor recall (fewer false negatives), 84 | values above 0.5 favor precision (fewer false positives). 85 | If None, use the model's default threshold. 86 | Defaults to None. 87 | 88 | Returns: 89 | List[tuple[int, int]]: List of character spans (start_index, end_index) 90 | """ 91 | paragraphs_with_spans = cls.get_paragraph_spans_with_text( 92 | segmenter, text, streaming, threshold 93 | ) 94 | return [span for _, span in paragraphs_with_spans] 95 | 96 | @classmethod 97 | # pylint: disable=too-many-locals,too-many-branches 98 | def get_paragraph_spans_with_text( 99 | cls, 100 | segmenter: "TextSegmenter", 101 | text: str, 102 | streaming: bool = False, # pylint: disable=unused-argument 103 | threshold: Optional[float] = None, 104 | ) -> List[Tuple[str, Tuple[int, int]]]: 105 | """ 106 | Segment text into a list of paragraphs with their character spans. 107 | 108 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 109 | and end_idx is exclusive (following Python's slicing convention). 110 | The spans are guaranteed to cover the entire input text without gaps. 111 | 112 | Args: 113 | segmenter: The TextSegmenter to use 114 | text (str): Text to segment 115 | streaming (bool, optional): Ignored. Included for API compatibility only. 116 | Defaults to False. 117 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 118 | Values below 0.5 favor recall (fewer false negatives), 119 | values above 0.5 favor precision (fewer false positives). 120 | If None, use the model's default threshold. 121 | Defaults to None. 
122 | 123 | Returns: 124 | List[Tuple[str, Tuple[int, int]]]: List of tuples containing 125 | (paragraph, (start_index, end_index)) 126 | """ 127 | # Quick return for empty text 128 | if not text: 129 | return [] 130 | 131 | # Direct search for multiple newlines after sentence boundaries 132 | 133 | # Get all sentence boundary positions 134 | boundary_positions = SpanHandler.find_boundary_positions( 135 | segmenter, text, threshold 136 | ) 137 | 138 | # If no boundaries, treat the whole text as one paragraph 139 | if not boundary_positions: 140 | return [(text.strip(), (0, len(text)))] 141 | 142 | # Find paragraph boundaries (sentence boundaries followed by 2+ newlines) 143 | paragraph_boundaries = [] 144 | 145 | for pos in boundary_positions: 146 | # Look for 2+ newlines right after this boundary 147 | window_end = min( 148 | pos + 10, len(text) 149 | ) # 10 char window is enough for newlines 150 | 151 | if pos < len(text): 152 | # Get the text slice to check 153 | window = text[pos:window_end] 154 | 155 | # Search for 2+ newlines in the window using the precompiled pattern 156 | match = PARAGRAPH_BREAK_PATTERN.search(window) 157 | 158 | # Only consider it a match if the newlines appear near the start of the window 159 | if ( 160 | match and match.start() <= 3 161 | ): # Allow for a few whitespace chars after sentence 162 | paragraph_boundaries.append(pos) 163 | 164 | # Always include the end of text as a paragraph boundary 165 | if not paragraph_boundaries or paragraph_boundaries[-1] != len(text): 166 | paragraph_boundaries.append(len(text)) 167 | 168 | # Create paragraph spans from boundaries 169 | result = [] 170 | start_idx = 0 171 | 172 | for end_idx in paragraph_boundaries: 173 | # Get the paragraph text (trim whitespace) 174 | paragraph_text = text[start_idx:end_idx].strip() 175 | 176 | # Only include non-empty paragraphs 177 | if paragraph_text: 178 | # For the span, include leading/trailing whitespace to ensure contiguity 179 | result.append((paragraph_text, (start_idx, end_idx))) 180 | 181 | start_idx = end_idx 182 | 183 | # Fix span coverage and ensure contiguity 184 | if result: 185 | # Make sure the first span starts at 0 186 | if result[0][1][0] > 0: 187 | paragraph_text, (_, end) = result[0] 188 | result[0] = (paragraph_text, (0, end)) 189 | 190 | # Make sure the last span ends at the text length 191 | if result[-1][1][1] < len(text): 192 | paragraph_text, (start, _) = result[-1] 193 | result[-1] = (paragraph_text, (start, len(text))) 194 | 195 | # Ensure contiguity between spans 196 | for i in range(len(result) - 1): 197 | if result[i][1][1] != result[i + 1][1][0]: 198 | paragraph_text, (start, _) = result[i] 199 | result[i] = (paragraph_text, (start, result[i + 1][1][0])) 200 | 201 | return result -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nupunkt" 3 | version = "0.5.1" 4 | description = "Next-generation Punkt sentence and paragraph boundary detection with zero dependencies" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | license = "MIT" 8 | authors = [ 9 | {name = "ALEA Institute", email = "hello@aleainstitute.ai"} 10 | ] 11 | keywords = [ 12 | "nlp", 13 | "natural language processing", 14 | "tokenization", 15 | "sentence boundary detection", 16 | "paragraph detection", 17 | "punkt", 18 | "text processing", 19 | "linguistics" 20 | ] 21 | classifiers = [ 22 | "Development Status :: 4 - Beta", 23 | 
"Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: Text Processing :: Linguistic", 30 | "Natural Language :: English", 31 | "Operating System :: OS Independent", 32 | ] 33 | dependencies = [ 34 | "tqdm>=4.65.0", 35 | # lzma is in the standard library for Python 3.3+ 36 | ] 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/alea-institute/nupunkt" 40 | "Source" = "https://github.com/alea-institute/nupunkt" 41 | "Bug Tracker" = "https://github.com/alea-institute/nupunkt/issues" 42 | "Changelog" = "https://github.com/alea-institute/nupunkt/blob/main/CHANGELOG.md" 43 | 44 | [project.optional-dependencies] 45 | dev = [ 46 | "pytest>=8.0.0", 47 | "mypy>=1.0.0", 48 | "ruff>=0.0.280", 49 | "black>=23.0.0", 50 | "pytest-benchmark>=5.1.0", 51 | "pytest-cov>=6.0.0", 52 | ] 53 | docs = [ 54 | "sphinx>=6.0.0", 55 | "sphinx-rtd-theme>=1.0.0", 56 | "myst-parser>=2.0.0", 57 | ] 58 | 59 | [build-system] 60 | requires = ["setuptools>=61.0.0", "wheel"] 61 | build-backend = "setuptools.build_meta" 62 | 63 | [tool.setuptools] 64 | packages = [ 65 | "nupunkt", 66 | "nupunkt.core", 67 | "nupunkt.tokenizers", 68 | "nupunkt.trainers", 69 | "nupunkt.utils", 70 | "nupunkt.models" 71 | ] 72 | 73 | [tool.setuptools.package-data] 74 | "nupunkt.models" = ["*.json", "*.json.xz", "*.bin"] 75 | 76 | [tool.ruff] 77 | target-version = "py311" 78 | line-length = 100 79 | 80 | # Enable linters in this section 81 | [tool.ruff.lint] 82 | select = [ 83 | "E", # pycodestyle errors 84 | "F", # pyflakes 85 | "B", # flake8-bugbear 86 | "I", # isort 87 | "N", # pep8-naming 88 | "UP", # pyupgrade 89 | "C4", # flake8-comprehensions 90 | "SIM", # flake8-simplify 91 | "PTH", # use pathlib 92 | ] 93 | 94 | ignore = [ 95 | "E501", # line too long - already handled by formatter 96 | "E741", # ambiguous variable name 'l' 97 | "E402", # module level import not at top of file - needed for sys.path manipulation 98 | "B905", # zip without explicit strict parameter 99 | "N803", # argument name should be lowercase 100 | "N806", # variable in function should be lowercase 101 | "N815", # variable in class scope should not be mixedCase 102 | "UP006", # use tuple instead of Tuple - maintain compatibility with Python 3.8 103 | "UP007", # use X | Y for type annotations - maintain compatibility with Python < 3.10 104 | "UP035", # deprecated import 105 | "PTH123", # use open() instead of Path.open() 106 | "SIM105", # use contextlib.suppress(Exception) instead of try/except pass 107 | ] 108 | 109 | # Allow unused variables when underscore-prefixed 110 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 111 | 112 | # Allow autofix for all enabled rules (when `--fix`) is provided 113 | fixable = ["ALL"] 114 | 115 | [tool.black] 116 | line-length = 100 117 | target-version = ["py311"] 118 | 119 | # This is now in mypy.ini, which has more detailed configuration 120 | 121 | # Testing config 122 | [tool.pytest.ini_options] 123 | testpaths = ["tests"] 124 | python_files = "test_*.py" 125 | addopts = "-v" 126 | 127 | [dependency-groups] 128 | dev = [ 129 | "memory-profiler>=0.61.0", 130 | "psutil>=7.0.0", 131 | "pytest-benchmark>=5.1.0", 132 | ] -------------------------------------------------------------------------------- /scripts/profile_default_model.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to profile the performance of the default model for nupunkt. 4 | 5 | This script: 6 | 1. Loads the default model from nupunkt/models/ 7 | 2. Profiles its performance using cProfile and line_profiler 8 | 3. Generates reports showing where time is spent 9 | """ 10 | 11 | import argparse 12 | import cProfile 13 | import gzip 14 | import json 15 | import pstats 16 | import sys 17 | import time 18 | from pathlib import Path 19 | from typing import Any, Dict, List 20 | 21 | # Add the parent directory to the path so we can import nupunkt 22 | script_dir = Path(__file__).parent 23 | root_dir = script_dir.parent 24 | sys.path.append(str(root_dir)) 25 | 26 | # Import nupunkt 27 | from nupunkt.tokenizers.sentence_tokenizer import PunktSentenceTokenizer 28 | 29 | # Test cases from test_default_model.py 30 | TEST_CASES = { 31 | "basic": [ 32 | "This is a simple sentence. This is another one.", 33 | "Hello world! How are you today? I'm doing well.", 34 | "The quick brown fox jumps over the lazy dog. The fox was very quick.", 35 | ], 36 | "abbreviations": [ 37 | "Dr. Smith went to Washington, D.C. He was very excited about the trip.", 38 | "The company (Ltd.) was founded in 1997. It has grown significantly since then.", 39 | "Mr. Johnson and Mrs. Lee will attend the meeting at 3 p.m. They will discuss the agenda.", 40 | "She has a B.A. in English. She also studied French in college.", 41 | "The U.S. economy is growing. Many industries are showing improvement.", 42 | ], 43 | "legal_citations": [ 44 | "Under 18 U.S.C. 12, this is a legal citation. The next sentence begins here.", 45 | "As stated in Fed. R. Civ. P. 56(c), summary judgment is appropriate. This standard is well established.", 46 | "In Smith v. Jones, 123 F.3d 456 (9th Cir. 1997), the court ruled in favor of the plaintiff. This set a precedent.", 47 | "According to Cal. Civ. Code § 123, the contract must be in writing. This requirement is strict.", 48 | ], 49 | "ellipsis": [ 50 | "This text contains an ellipsis... And this is a new sentence.", 51 | "The story continues... But not for long.", 52 | "He paused for a moment... Then he continued walking.", 53 | "She thought about it for a while... Then she made her decision.", 54 | ], 55 | "other_punctuation": [ 56 | "Let me give you an example, e.g. this one. Did you understand it?", 57 | "The company (formerly known as Tech Solutions, Inc.) was acquired last year. The new owners rebranded it.", 58 | "The meeting is at 3 p.m. Don't be late!", 59 | 'He said, "I\'ll be there at 5 p.m." Then he hung up the phone.', 60 | ], 61 | "challenging": [ 62 | "The patient presented with abd. pain. CT scan was ordered.", 63 | "The table shows results for Jan. Feb. and Mar. Each month shows improvement.", 64 | "Visit the website at www.example.com. There you'll find more information.", 65 | "She scored 92 vs. 85 in the previous match. Her performance has improved.", 66 | "The temperature was 32 deg. C. 
It was quite hot that day.", 67 | ], 68 | } 69 | 70 | 71 | def load_test_data(file_path: Path) -> List[str]: 72 | """Load text examples from a test file.""" 73 | print(f"Loading test data from {file_path}...") 74 | if file_path.suffix == ".gz": 75 | with gzip.open(file_path, "rt", encoding="utf-8") as f: 76 | if file_path.suffix == ".jsonl.gz": 77 | # Handle JSONL format 78 | texts = [] 79 | for line in f: 80 | try: 81 | data = json.loads(line) 82 | if "text" in data: 83 | texts.append(data["text"]) 84 | except json.JSONDecodeError: 85 | print(f"Warning: Could not parse line in {file_path}") 86 | return texts 87 | else: 88 | # Plain text in gzip 89 | return f.read().split("\n\n") 90 | else: 91 | # Regular text file 92 | with open(file_path, encoding="utf-8") as f: 93 | return f.read().split("\n\n") 94 | 95 | 96 | def test_model_accuracy(tokenizer: PunktSentenceTokenizer, texts: List[str]) -> Dict[str, Any]: 97 | """Test the tokenizer's performance on a list of texts.""" 98 | start_time = time.time() 99 | total_chars = sum(len(text) for text in texts) 100 | total_sentences = 0 101 | 102 | for text in texts: 103 | sentences = tokenizer.tokenize(text) 104 | total_sentences += len(sentences) 105 | 106 | end_time = time.time() 107 | processing_time = end_time - start_time 108 | chars_per_second = total_chars / processing_time if processing_time > 0 else 0 109 | 110 | return { 111 | "total_texts": len(texts), 112 | "total_chars": total_chars, 113 | "total_sentences": total_sentences, 114 | "processing_time_seconds": processing_time, 115 | "chars_per_second": chars_per_second, 116 | } 117 | 118 | 119 | def run_cprofile(tokenizer: PunktSentenceTokenizer, texts: List[str], output_path: Path) -> None: 120 | """Run cProfile on the tokenizer with the given texts.""" 121 | print("\n=== Running cProfile Analysis ===") 122 | 123 | # Create profile output directory if it doesn't exist 124 | output_path.parent.mkdir(exist_ok=True, parents=True) 125 | 126 | # Run with cProfile 127 | profiler = cProfile.Profile() 128 | profiler.enable() 129 | 130 | # Run the actual test 131 | for text in texts: 132 | tokenizer.tokenize(text) 133 | 134 | profiler.disable() 135 | 136 | # Save stats to file 137 | stats_path = output_path.with_suffix(".prof") 138 | profiler.dump_stats(str(stats_path)) 139 | print(f"cProfile data saved to: {stats_path}") 140 | 141 | # Generate readable report 142 | stats = pstats.Stats(str(stats_path)) 143 | txt_path = output_path.with_suffix(".txt") 144 | with open(txt_path, "w") as f: 145 | sys.stdout = f # Redirect stdout to file 146 | stats.sort_stats("cumulative").print_stats(30) 147 | sys.stdout = sys.__stdout__ # Reset stdout 148 | 149 | print(f"cProfile report saved to: {txt_path}") 150 | 151 | # Print summary to console 152 | print("\nTop 10 functions by cumulative time:") 153 | stats.sort_stats("cumulative").print_stats(10) 154 | 155 | 156 | def run_line_profiler( 157 | tokenizer: PunktSentenceTokenizer, texts: List[str], output_path: Path 158 | ) -> None: 159 | """Run line_profiler on the tokenizer with the given texts.""" 160 | try: 161 | from line_profiler import LineProfiler 162 | except ImportError: 163 | print("line_profiler not installed. 
Skipping line-by-line profiling.") 164 | print("Install with: pip install line_profiler") 165 | return 166 | 167 | print("\n=== Running Line Profiler Analysis ===") 168 | 169 | # Create a line profiler and add functions to profile 170 | lp = LineProfiler() 171 | lp.add_function(tokenizer.tokenize) 172 | 173 | # Try to profile important internal methods if they exist 174 | for method_name in [ 175 | "_slices_from_text", 176 | "_annotate_tokens", 177 | "_tokenize_words", 178 | "_handle_abbrev", 179 | "_handle_potential_sentence_break", 180 | ]: 181 | if hasattr(tokenizer, method_name): 182 | method = getattr(tokenizer, method_name) 183 | lp.add_function(method) 184 | 185 | # Profile the tokenization process 186 | lp_wrapper = lp(lambda: [tokenizer.tokenize(text) for text in texts]) 187 | lp_wrapper() 188 | 189 | # Save the line profiler stats to text file 190 | txt_path = output_path.with_suffix(".line.txt") 191 | 192 | # Generate text report 193 | with open(txt_path, "w") as f: 194 | sys.stdout = f # Redirect stdout to file 195 | lp.print_stats() 196 | sys.stdout = sys.__stdout__ # Reset stdout 197 | 198 | print(f"Line profiler report saved to: {txt_path}") 199 | 200 | # Print summary to console 201 | print("\nLine-by-line profiling results:") 202 | lp.print_stats() 203 | 204 | 205 | def parse_args(): 206 | parser = argparse.ArgumentParser(description="Profile the nupunkt default model") 207 | parser.add_argument( 208 | "--examples-only", action="store_true", help="Only use predefined examples (faster)" 209 | ) 210 | parser.add_argument( 211 | "--cprofile", 212 | action="store_true", 213 | default=True, 214 | help="Run cProfile profiling (default: True)", 215 | ) 216 | parser.add_argument( 217 | "--line-profiler", 218 | action="store_true", 219 | help="Run line_profiler profiling (requires line_profiler package)", 220 | ) 221 | parser.add_argument( 222 | "--output", 223 | type=str, 224 | default="profile_results", 225 | help="Base name for output files (without extension)", 226 | ) 227 | return parser.parse_args() 228 | 229 | 230 | def main(): 231 | """Profile the default model and report results.""" 232 | args = parse_args() 233 | 234 | # Set paths 235 | models_dir = root_dir / "nupunkt" / "models" 236 | model_path = models_dir / "default_model.bin" 237 | test_path = root_dir / "data" / "test.jsonl.gz" 238 | output_dir = root_dir / "profiles" 239 | output_dir.mkdir(exist_ok=True) 240 | output_path = output_dir / args.output 241 | 242 | # Check if model exists 243 | if not model_path.exists(): 244 | raise FileNotFoundError(f"Model file not found: {model_path}") 245 | 246 | # Load the model 247 | print(f"Loading model from {model_path}...") 248 | tokenizer = PunktSentenceTokenizer.load(model_path) 249 | print("Model loaded successfully") 250 | 251 | # Determine which texts to use for profiling 252 | if args.examples_only: 253 | # Use only the predefined examples 254 | all_examples = [] 255 | for examples in TEST_CASES.values(): 256 | all_examples.extend(examples) 257 | texts = all_examples 258 | print(f"Using {len(texts)} predefined examples for profiling") 259 | else: 260 | # Use test data if available 261 | if test_path.exists(): 262 | texts = load_test_data(test_path) 263 | print(f"Using {len(texts)} test documents for profiling") 264 | else: 265 | print(f"Test data file not found: {test_path}") 266 | print("Falling back to predefined examples") 267 | all_examples = [] 268 | for examples in TEST_CASES.values(): 269 | all_examples.extend(examples) 270 | texts = all_examples 271 | 
print(f"Using {len(texts)} predefined examples for profiling") 272 | 273 | # Run requested profiling 274 | if args.cprofile: 275 | run_cprofile(tokenizer, texts, output_path) 276 | 277 | if args.line_profiler: 278 | run_line_profiler(tokenizer, texts, output_path) 279 | 280 | print("\nProfiling completed successfully.") 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /scripts/utils/README.md: -------------------------------------------------------------------------------- 1 | # Nupunkt Utility Scripts 2 | 3 | Scripts to manage, optimize, and analyze nupunkt models. 4 | 5 | ## Utility Scripts 6 | 7 | ### Model Management 8 | 9 | - `benchmark_load_times.py`: Compare load times of different model formats 10 | - `convert_model.py`: Convert between different model storage formats 11 | - `model_info.py`: Display information about a model file 12 | - `optimize_model.py`: Optimize a model file for size and loading performance 13 | 14 | ### Tokenization and Analysis 15 | 16 | - `check_abbreviation.py`: Check if a token is in the model's abbreviation list 17 | - `test_tokenizer.py`: Test the tokenizer on sample text 18 | - `explain_tokenization.py`: Show detailed explanation for tokenization decisions 19 | 20 | ## Usage Examples 21 | 22 | ### Check Abbreviation Tool 23 | 24 | Check if a specific token is recognized as an abbreviation in the model: 25 | 26 | ```bash 27 | # Check a specific token 28 | python check_abbreviation.py "Dr." 29 | 30 | # List all abbreviations in the model 31 | python check_abbreviation.py --list 32 | 33 | # Count total abbreviations 34 | python check_abbreviation.py --count 35 | 36 | # Find abbreviations starting with a specific prefix 37 | python check_abbreviation.py --startswith "u.s" 38 | 39 | # Check in a custom model 40 | python check_abbreviation.py "Dr." --model /path/to/custom_model.bin 41 | ``` 42 | -------------------------------------------------------------------------------- /scripts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility scripts for managing nupunkt models. 3 | 4 | This package contains various utility scripts for: 5 | - Converting models between formats 6 | - Optimizing model storage 7 | - Displaying model information 8 | - Testing models with custom text 9 | - Benchmarking model performance 10 | """ 11 | -------------------------------------------------------------------------------- /scripts/utils/check_abbreviation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to check if a token is in the model's abbreviation list. 4 | 5 | This utility allows users to verify if a specific token is recognized as an abbreviation 6 | in the default model or a custom model. 7 | """ 8 | 9 | import argparse 10 | import os 11 | import sys 12 | from pathlib import Path 13 | from typing import Optional, Set, Tuple 14 | 15 | # Add the parent directory to the path so we can import nupunkt 16 | script_dir = Path(__file__).parent 17 | root_dir = script_dir.parent.parent 18 | sys.path.append(str(root_dir)) 19 | 20 | # Import nupunkt 21 | from nupunkt.core.parameters import PunktParameters 22 | from nupunkt.models import load_default_model 23 | 24 | 25 | def check_abbreviation(token: str, model_path: Optional[str] = None) -> Tuple[bool, Set[str]]: 26 | """ 27 | Check if a token is in the model's abbreviation list. 
28 | 29 | Args: 30 | token: The token to check 31 | model_path: Optional path to a custom model file 32 | 33 | Returns: 34 | A tuple containing: 35 | - True if the token is in the abbreviation list, False otherwise 36 | - The set of abbreviation types from the model 37 | """ 38 | # Load the model 39 | if model_path: 40 | model_path_obj = Path(model_path) 41 | # Load parameters directly from file 42 | params = PunktParameters.load(model_path_obj) 43 | abbrev_types = params.abbrev_types 44 | else: 45 | # Load the default model 46 | tokenizer = load_default_model() 47 | # Access to protected member _params is necessary as there's no public API 48 | # to get the abbreviation types directly from the tokenizer 49 | abbrev_types = tokenizer._params.abbrev_types 50 | 51 | # Clean the token for checking (remove trailing period if present) 52 | clean_token = token.lower() 53 | if clean_token.endswith("."): 54 | clean_token = clean_token[:-1] 55 | 56 | # Check if the token is in the abbreviation list 57 | is_abbrev = clean_token in abbrev_types 58 | 59 | return is_abbrev, abbrev_types 60 | 61 | 62 | def main() -> None: 63 | """Check if a token is in the model's abbreviation list.""" 64 | parser = argparse.ArgumentParser( 65 | description="Check if a token is recognized as an abbreviation in the nupunkt model" 66 | ) 67 | parser.add_argument("token", type=str, nargs="?", help="The token to check") 68 | parser.add_argument("--model", "-m", type=str, help="Path to a custom model file") 69 | parser.add_argument( 70 | "--list", "-l", action="store_true", help="List all abbreviations in the model" 71 | ) 72 | parser.add_argument( 73 | "--startswith", "-s", type=str, help="List abbreviations starting with the given prefix" 74 | ) 75 | parser.add_argument( 76 | "--count", 77 | "-c", 78 | action="store_true", 79 | help="Show the total count of abbreviations in the model", 80 | ) 81 | 82 | args = parser.parse_args() 83 | 84 | # For operations that don't require a specific token 85 | if args.list or args.startswith is not None or args.count: 86 | dummy_token = "a" # Just use a dummy token to load the model 87 | _, abbrev_types = check_abbreviation(dummy_token, args.model) 88 | 89 | if args.list: 90 | # Sort and list all abbreviations 91 | sorted_abbrevs = sorted(abbrev_types) 92 | print(f"\nAll abbreviations in the model ({len(sorted_abbrevs)}):") 93 | for abbrev in sorted_abbrevs: 94 | print(f" {abbrev}") 95 | print() 96 | 97 | if args.startswith is not None: 98 | # List abbreviations starting with the given prefix 99 | prefix = args.startswith.lower() 100 | matching_abbrevs = [abbr for abbr in abbrev_types if abbr.startswith(prefix)] 101 | sorted_matches = sorted(matching_abbrevs) 102 | print(f"\nAbbreviations starting with '{prefix}' ({len(sorted_matches)}):") 103 | for abbrev in sorted_matches: 104 | print(f" {abbrev}") 105 | print() 106 | 107 | if args.count: 108 | # Show total count 109 | print(f"\nTotal abbreviations in the model: {len(abbrev_types)}\n") 110 | 111 | return 112 | 113 | # For checking a specific token 114 | if args.token is None: 115 | parser.print_help() 116 | print("\nError: Please provide a token to check or use --list, --startswith, or --count.") 117 | sys.exit(1) 118 | 119 | # Load the model and get abbreviations 120 | is_abbrev, abbrev_types = check_abbreviation(args.token, args.model) 121 | 122 | # Check the specific token 123 | token = args.token 124 | clean_token = token.lower() 125 | if clean_token.endswith("."): 126 | clean_token = clean_token[:-1] 127 | 128 | if is_abbrev: 129 
| print(f"\nYes, '{clean_token}' is recognized as an abbreviation in the model.\n") 130 | else: 131 | print(f"\nNo, '{clean_token}' is NOT recognized as an abbreviation in the model.\n") 132 | 133 | 134 | if __name__ == "__main__": 135 | try: 136 | main() 137 | except BrokenPipeError: 138 | # This prevents the "Broken pipe" error message when piping output to tools like 'head' 139 | # Python flushes standard streams on exit; redirect remaining output 140 | # to /dev/null to avoid another BrokenPipeError at shutdown 141 | devnull = os.open(os.devnull, os.O_WRONLY) 142 | os.dup2(devnull, sys.stdout.fileno()) 143 | sys.exit(0) 144 | -------------------------------------------------------------------------------- /scripts/utils/convert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to convert nupunkt models between different storage formats. 4 | 5 | This script: 6 | 1. Loads a model from a specified file or the default model 7 | 2. Converts it to a different format (binary, json_xz, json) 8 | 3. Saves it to a specified location 9 | 10 | Usage: 11 | python convert_model.py --input model.json.xz --output model.bin --format binary --compression zlib 12 | """ 13 | 14 | import argparse 15 | import sys 16 | import time 17 | from pathlib import Path 18 | 19 | # Add the parent directory to the path so we can import nupunkt 20 | script_dir = Path(__file__).parent 21 | root_dir = script_dir.parent 22 | sys.path.append(str(root_dir)) 23 | 24 | # Import nupunkt 25 | from nupunkt.utils.compression import ( 26 | load_binary_model, 27 | load_compressed_json, 28 | save_binary_model, 29 | save_compressed_json, 30 | ) 31 | 32 | 33 | def convert_model( 34 | input_path: Path, 35 | output_path: Path, 36 | format_type: str, 37 | compression_method: str, 38 | compression_level: int, 39 | ) -> None: 40 | """ 41 | Convert a model from one format to another. 42 | 43 | Args: 44 | input_path: Path to the input model file 45 | output_path: Path to save the converted model to 46 | format_type: Output format type ('binary', 'json_xz', 'json') 47 | compression_method: Compression method for binary format 48 | compression_level: Compression level (0-9) 49 | """ 50 | print(f"Loading model from {input_path}...") 51 | start_time = time.time() 52 | 53 | # Load the model data 54 | if str(input_path).endswith(".bin"): 55 | data = load_binary_model(input_path) 56 | else: 57 | data = load_compressed_json(input_path) 58 | 59 | load_time = time.time() - start_time 60 | print(f"Model loaded in {load_time:.3f} seconds") 61 | 62 | # Get input file size 63 | input_size = Path(input_path).stat().st_size 64 | print(f"Input file size: {input_size / 1024:.2f} KB") 65 | 66 | # Convert the model 67 | print( 68 | f"Converting to {format_type} format with {compression_method} compression (level {compression_level})..." 
69 | ) 70 | start_time = time.time() 71 | 72 | if format_type == "binary": 73 | save_binary_model( 74 | data, output_path, compression_method=compression_method, level=compression_level 75 | ) 76 | else: 77 | save_compressed_json( 78 | data, output_path, level=compression_level, use_compression=(format_type == "json_xz") 79 | ) 80 | 81 | convert_time = time.time() - start_time 82 | print(f"Conversion completed in {convert_time:.3f} seconds") 83 | 84 | # Get output file size 85 | output_size = Path(output_path).stat().st_size 86 | print(f"Output file size: {output_size / 1024:.2f} KB") 87 | 88 | # Show compression ratio 89 | ratio = output_size / input_size 90 | print(f"Compression ratio: {ratio:.3f} (smaller is better)") 91 | 92 | # Verify the model can be loaded 93 | print("Verifying model can be loaded...") 94 | start_time = time.time() 95 | 96 | if format_type == "binary": 97 | _ = load_binary_model(output_path) 98 | else: 99 | _ = load_compressed_json(output_path) 100 | 101 | verify_time = time.time() - start_time 102 | print(f"Model verified in {verify_time:.3f} seconds") 103 | 104 | print(f"Model successfully converted and saved to {output_path}") 105 | 106 | 107 | def main(): 108 | """Run the model conversion.""" 109 | parser = argparse.ArgumentParser(description="Convert nupunkt models between different formats") 110 | parser.add_argument( 111 | "--input", 112 | type=str, 113 | default=None, 114 | help="Path to the input model file (default: use the default model)", 115 | ) 116 | parser.add_argument( 117 | "--output", type=str, required=True, help="Path to save the converted model to" 118 | ) 119 | parser.add_argument( 120 | "--format", 121 | type=str, 122 | default="binary", 123 | choices=["json", "json_xz", "binary"], 124 | help="Format to convert the model to", 125 | ) 126 | parser.add_argument( 127 | "--compression", 128 | type=str, 129 | default="zlib", 130 | choices=["none", "zlib", "lzma", "gzip"], 131 | help="Compression method for binary format", 132 | ) 133 | parser.add_argument("--level", type=int, default=6, help="Compression level (0-9)") 134 | 135 | args = parser.parse_args() 136 | 137 | # Use default model if input not specified 138 | if args.input: 139 | input_path = Path(args.input) 140 | else: 141 | from nupunkt.models import get_default_model_path 142 | 143 | input_path = get_default_model_path() 144 | 145 | output_path = Path(args.output) 146 | 147 | # Create output directory if it doesn't exist 148 | output_path.parent.mkdir(parents=True, exist_ok=True) 149 | 150 | # Convert the model 151 | convert_model( 152 | input_path=input_path, 153 | output_path=output_path, 154 | format_type=args.format, 155 | compression_method=args.compression, 156 | compression_level=args.level, 157 | ) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /scripts/utils/model_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to display information about nupunkt models. 4 | 5 | This script: 6 | 1. Loads a model from a specified file or the default model 7 | 2. Displays key information about the model 8 | 3. 
Provides options to convert between formats or display details 9 | 10 | Usage: 11 | python model_info.py --input model.bin 12 | python model_info.py --input model.bin --stats 13 | python model_info.py --input model.bin --convert output.json --format json 14 | """ 15 | 16 | import argparse 17 | import sys 18 | import time 19 | from collections import Counter 20 | from pathlib import Path 21 | from typing import Any, Dict 22 | 23 | # Add the parent directory to the path so we can import nupunkt 24 | script_dir = Path(__file__).parent 25 | root_dir = script_dir.parent 26 | sys.path.append(str(root_dir)) 27 | 28 | # Import nupunkt 29 | from nupunkt.utils.compression import ( 30 | load_binary_model, 31 | load_compressed_json, 32 | save_binary_model, 33 | save_compressed_json, 34 | ) 35 | 36 | 37 | def format_size(size_bytes: int) -> str: 38 | """Format file size in human-readable format.""" 39 | kb = size_bytes / 1024 40 | if kb < 1000: 41 | return f"{kb:.2f} KB" 42 | mb = kb / 1024 43 | return f"{mb:.2f} MB" 44 | 45 | 46 | def load_model(input_path: Path) -> Dict[str, Any]: 47 | """ 48 | Load a model from the specified path. 49 | 50 | Args: 51 | input_path: Path to the model file 52 | 53 | Returns: 54 | The model data as a dictionary 55 | """ 56 | start_time = time.time() 57 | 58 | # Load the model data 59 | if str(input_path).endswith(".bin"): 60 | data = load_binary_model(input_path) 61 | model_format = "binary" 62 | else: 63 | data = load_compressed_json(input_path) 64 | model_format = "json_xz" if str(input_path).endswith(".json.xz") else "json" 65 | 66 | load_time = time.time() - start_time 67 | 68 | # Handle the trainer format if present 69 | if "parameters" in data: 70 | params = data["parameters"] 71 | trainer_params = {k: v for k, v in data.items() if k != "parameters"} 72 | else: 73 | params = data 74 | trainer_params = {} 75 | 76 | return { 77 | "path": input_path, 78 | "format": model_format, 79 | "size": Path(input_path).stat().st_size, 80 | "load_time": load_time, 81 | "params": params, 82 | "trainer_params": trainer_params, 83 | } 84 | 85 | 86 | def display_model_info(model_data: Dict[str, Any], show_stats: bool = False) -> None: 87 | """ 88 | Display information about the model. 
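
For reference, the conversion that convert_model.py and model_info.py perform from the command line can also be done programmatically with the compression helpers they import. A minimal sketch, assuming the default model ships as the binary .bin file shown in the tree; the output filename is arbitrary.

from pathlib import Path

from nupunkt.models import get_default_model_path
from nupunkt.utils.compression import load_binary_model, save_compressed_json

# Load the shipped binary model and re-save it as LZMA-compressed JSON.
src = Path(get_default_model_path())   # e.g. nupunkt/models/default_model.bin
data = load_binary_model(src)

dst = Path("default_model.json.xz")    # arbitrary output location
save_compressed_json(data, dst, level=6, use_compression=True)

print(f"{src.stat().st_size / 1024:.1f} KB -> {dst.stat().st_size / 1024:.1f} KB")
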
89 | 90 | Args: 91 | model_data: Dictionary containing model information 92 | show_stats: Whether to display detailed statistics 93 | """ 94 | print("\n=== Model Information ===") 95 | print(f"File: {model_data['path']}") 96 | print(f"Format: {model_data['format']}") 97 | print(f"Size: {format_size(model_data['size'])}") 98 | print(f"Load time: {model_data['load_time']:.3f} seconds") 99 | 100 | params = model_data["params"] 101 | print("\nParameters:") 102 | print(f" Abbreviation types: {len(params.get('abbrev_types', []))}") 103 | print(f" Collocations: {len(params.get('collocations', []))}") 104 | print(f" Sentence starters: {len(params.get('sent_starters', []))}") 105 | print(f" Orthographic context: {len(params.get('ortho_context', {}))}") 106 | 107 | trainer_params = model_data["trainer_params"] 108 | if trainer_params: 109 | print("\nTrainer parameters:") 110 | for key, value in trainer_params.items(): 111 | # Skip large collections like common_abbrevs 112 | if isinstance(value, list | dict | set) and len(value) > 10: 113 | print(f" {key}: {len(value)} items") 114 | else: 115 | print(f" {key}: {value}") 116 | 117 | if show_stats: 118 | display_model_stats(model_data) 119 | 120 | 121 | def display_model_stats(model_data: Dict[str, Any]) -> None: 122 | """ 123 | Display detailed statistics about the model. 124 | 125 | Args: 126 | model_data: Dictionary containing model information 127 | """ 128 | params = model_data["params"] 129 | 130 | print("\n=== Model Statistics ===") 131 | 132 | # Abbreviation statistics 133 | abbrev_types = params.get("abbrev_types", []) 134 | if abbrev_types: 135 | print(f"\nAbbreviation types ({len(abbrev_types)}):") 136 | print(" Most common types by first letter:") 137 | first_letter_counts = Counter(abbr[0] if abbr else "" for abbr in abbrev_types) 138 | for letter, count in first_letter_counts.most_common(10): 139 | print(f" {letter}: {count}") 140 | 141 | print("\n Abbreviation examples:") 142 | for abbr in sorted(abbrev_types)[:10]: 143 | print(f" {abbr}") 144 | 145 | # Collocation statistics 146 | collocations = params.get("collocations", []) 147 | if collocations: 148 | print(f"\nCollocations ({len(collocations)}):") 149 | print(" Examples:") 150 | for w1, w2 in sorted(collocations)[:10]: 151 | print(f" {w1} {w2}") 152 | 153 | # Sentence starter statistics 154 | sent_starters = params.get("sent_starters", []) 155 | if sent_starters: 156 | print(f"\nSentence starters ({len(sent_starters)}):") 157 | print(" Examples:") 158 | for starter in sorted(sent_starters)[:10]: 159 | print(f" {starter}") 160 | 161 | # Orthographic context statistics 162 | ortho_context = params.get("ortho_context", {}) 163 | if ortho_context: 164 | print(f"\nOrthographic context ({len(ortho_context)}):") 165 | print(" Most common flag values:") 166 | flags = Counter(ortho_context.values()) 167 | for flag, count in flags.most_common(5): 168 | print(f" Flag {flag}: {count} types") 169 | 170 | 171 | def convert_model( 172 | model_data: Dict[str, Any], 173 | output_path: Path, 174 | format_type: str, 175 | compression_method: str, 176 | compression_level: int, 177 | ) -> None: 178 | """ 179 | Convert a model to a different format. 
180 | 181 | Args: 182 | model_data: Dictionary containing model information 183 | output_path: Path to save the converted model to 184 | format_type: Output format type ('binary', 'json_xz', 'json') 185 | compression_method: Compression method for binary format 186 | compression_level: Compression level (0-9) 187 | """ 188 | print("\n=== Converting Model ===") 189 | print(f"Source: {model_data['path']}") 190 | print(f"Target: {output_path}") 191 | print(f"Format: {format_type}") 192 | if format_type == "binary": 193 | print(f"Compression: {compression_method} (level {compression_level})") 194 | elif format_type == "json_xz": 195 | print(f"Compression level: {compression_level}") 196 | 197 | start_time = time.time() 198 | 199 | # Reconstruct original data format 200 | if model_data["trainer_params"]: 201 | # This was a trainer format 202 | data = model_data["trainer_params"].copy() 203 | data["parameters"] = model_data["params"] 204 | else: 205 | # This was a direct format 206 | data = model_data["params"] 207 | 208 | # Convert the model 209 | if format_type == "binary": 210 | save_binary_model( 211 | data, output_path, compression_method=compression_method, level=compression_level 212 | ) 213 | else: 214 | save_compressed_json( 215 | data, output_path, level=compression_level, use_compression=(format_type == "json_xz") 216 | ) 217 | 218 | convert_time = time.time() - start_time 219 | print(f"Conversion completed in {convert_time:.3f} seconds") 220 | 221 | # Get output file size 222 | output_size = Path(output_path).stat().st_size 223 | print(f"Output file size: {format_size(output_size)}") 224 | 225 | # Show compression ratio 226 | ratio = output_size / model_data["size"] 227 | print(f"Size ratio: {ratio:.3f} (< 1 means smaller)") 228 | 229 | print(f"Model successfully converted and saved to {output_path}") 230 | 231 | 232 | def main(): 233 | """Process command-line arguments and run the requested operations.""" 234 | parser = argparse.ArgumentParser(description="Display information about nupunkt models") 235 | parser.add_argument( 236 | "--input", 237 | type=str, 238 | default=None, 239 | help="Path to the input model file (default: use the default model)", 240 | ) 241 | parser.add_argument( 242 | "--stats", action="store_true", help="Display detailed statistics about the model" 243 | ) 244 | parser.add_argument( 245 | "--convert", 246 | type=str, 247 | default=None, 248 | help="Convert the model to a different format and save to the specified path", 249 | ) 250 | parser.add_argument( 251 | "--format", 252 | type=str, 253 | default="binary", 254 | choices=["json", "json_xz", "binary"], 255 | help="Format to convert the model to", 256 | ) 257 | parser.add_argument( 258 | "--compression", 259 | type=str, 260 | default="lzma", 261 | choices=["none", "zlib", "lzma", "gzip"], 262 | help="Compression method for binary format", 263 | ) 264 | parser.add_argument("--level", type=int, default=6, help="Compression level (0-9)") 265 | 266 | args = parser.parse_args() 267 | 268 | # Use default model if input not specified 269 | if args.input: 270 | input_path = Path(args.input) 271 | else: 272 | from nupunkt.models import get_default_model_path 273 | 274 | input_path = get_default_model_path() 275 | 276 | # Load the model 277 | try: 278 | model_data = load_model(input_path) 279 | except Exception as e: 280 | print(f"Error loading model: {e}") 281 | sys.exit(1) 282 | 283 | # Display model information 284 | display_model_info(model_data, args.stats) 285 | 286 | # Convert the model if requested 287 | if 
args.convert: 288 | output_path = Path(args.convert) 289 | 290 | # Create output directory if it doesn't exist 291 | output_path.parent.mkdir(parents=True, exist_ok=True) 292 | 293 | try: 294 | convert_model( 295 | model_data=model_data, 296 | output_path=output_path, 297 | format_type=args.format, 298 | compression_method=args.compression, 299 | compression_level=args.level, 300 | ) 301 | except Exception as e: 302 | print(f"Error converting model: {e}") 303 | sys.exit(1) 304 | 305 | 306 | if __name__ == "__main__": 307 | main() 308 | -------------------------------------------------------------------------------- /scripts/utils/optimize_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to optimize the model storage format for nupunkt. 4 | 5 | This script: 6 | 1. Loads the current default model 7 | 2. Converts it to binary format with compression 8 | 3. Shows the size difference 9 | """ 10 | 11 | import argparse 12 | import sys 13 | import time 14 | from pathlib import Path 15 | 16 | # Add the parent directory to the path so we can import nupunkt 17 | script_dir = Path(__file__).parent 18 | root_dir = script_dir.parent 19 | sys.path.append(str(root_dir)) 20 | 21 | # Import nupunkt 22 | from nupunkt.models import get_default_model_path 23 | from nupunkt.utils.compression import load_compressed_json, save_binary_model 24 | 25 | 26 | def optimize_model(format_type, compression_method, compression_level, output_path=None): 27 | """Optimize the model storage format.""" 28 | # Get the current model path 29 | current_model_path = get_default_model_path() 30 | print(f"Current model: {current_model_path}") 31 | 32 | # Load the model data 33 | print("Loading model data...") 34 | data = load_compressed_json(current_model_path) 35 | 36 | # Get original file size 37 | original_size = Path(current_model_path).stat().st_size 38 | print(f"Original size: {original_size / 1024:.2f} KB") 39 | 40 | # Determine output path 41 | if output_path is None: 42 | if format_type == "binary": 43 | output_path = root_dir / "nupunkt" / "models" / "default_model.bin" 44 | else: 45 | output_path = current_model_path 46 | else: 47 | output_path = Path(output_path) 48 | 49 | # Save in binary format 50 | print( 51 | f"Saving in {format_type} format with {compression_method} compression (level {compression_level})..." 
52 | ) 53 | 54 | if format_type == "binary": 55 | save_binary_model( 56 | data, output_path, compression_method=compression_method, level=compression_level 57 | ) 58 | else: 59 | from nupunkt.utils.compression import save_compressed_json 60 | 61 | save_compressed_json( 62 | data, output_path, level=compression_level, use_compression=(format_type == "json_xz") 63 | ) 64 | 65 | # Get new file size 66 | new_size = Path(output_path).stat().st_size 67 | print(f"New size: {new_size / 1024:.2f} KB") 68 | 69 | # Calculate compression ratio 70 | ratio = new_size / original_size 71 | print(f"Compression ratio: {ratio:.3f} (smaller is better)") 72 | 73 | # Verify the model can be loaded 74 | print("Verifying model can be loaded...") 75 | start_time = time.time() 76 | if format_type == "binary": 77 | from nupunkt.utils.compression import load_binary_model 78 | 79 | _ = load_binary_model(output_path) 80 | else: 81 | _ = load_compressed_json(output_path) 82 | load_time = time.time() - start_time 83 | print(f"Model loaded successfully in {load_time:.5f} seconds") 84 | 85 | return output_path 86 | 87 | 88 | def main(): 89 | """Run the model optimization.""" 90 | parser = argparse.ArgumentParser(description="Optimize the model storage format for nupunkt") 91 | parser.add_argument( 92 | "--format", 93 | type=str, 94 | default="binary", 95 | choices=["json", "json_xz", "binary"], 96 | help="Format to save the model in", 97 | ) 98 | parser.add_argument( 99 | "--compression", 100 | type=str, 101 | default="zlib", 102 | choices=["none", "zlib", "lzma", "gzip"], 103 | help="Compression method for binary format", 104 | ) 105 | parser.add_argument("--level", type=int, default=6, help="Compression level (0-9)") 106 | parser.add_argument( 107 | "--output", type=str, default=None, help="Custom output path for the optimized model" 108 | ) 109 | 110 | args = parser.parse_args() 111 | 112 | # Optimize the model 113 | output_path = optimize_model( 114 | format_type=args.format, 115 | compression_method=args.compression, 116 | compression_level=args.level, 117 | output_path=args.output, 118 | ) 119 | 120 | print(f"Model optimized successfully and saved to: {output_path}") 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /scripts/utils/profile_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Profiling script for nupunkt sentence tokenizer. 4 | 5 | This script profiles the performance of the sentence tokenizer 6 | to identify bottlenecks and optimization opportunities. 
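
The profiling scripts in this repository all follow the same standard-library pattern: collect stats with cProfile, then sort and print them with pstats. A minimal, self-contained sketch of that pattern applied to an arbitrary stand-in function (the function and output filename are illustrative):

import cProfile
import io
import pstats


def work() -> int:
    """Stand-in for tokenizer.tokenize(...); any CPU-bound callable works."""
    return sum(i * i for i in range(200_000))


profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()

# Sort by cumulative time and print the top entries, as the scripts above do.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer).sort_stats("cumulative")
stats.print_stats(10)
print(buffer.getvalue())

# Optionally persist the raw stats for later inspection (e.g. with snakeviz).
stats.dump_stats("work_profile.prof")
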
7 | """ 8 | 9 | import cProfile 10 | import io 11 | import pstats 12 | import sys 13 | from pathlib import Path 14 | 15 | # Add the parent directory to the path so we can import nupunkt 16 | script_dir = Path(__file__).parent.parent 17 | root_dir = script_dir.parent 18 | sys.path.append(str(root_dir)) 19 | 20 | # Import nupunkt 21 | from nupunkt.models import load_default_model 22 | 23 | # Import test data loading function 24 | sys.path.append(str(script_dir)) 25 | from test_default_model import load_test_data 26 | 27 | 28 | def profile_tokenization(): 29 | """Profile the sentence tokenization process.""" 30 | print("Loading default model...") 31 | tokenizer = load_default_model() 32 | 33 | # Load test data 34 | test_path = root_dir / "data" / "test.jsonl.gz" 35 | if not test_path.exists(): 36 | print(f"Error: Test data file not found: {test_path}") 37 | return 38 | 39 | test_texts = load_test_data(test_path) 40 | print(f"Loaded {len(test_texts)} test documents.") 41 | 42 | # Create a sample with first N characters to avoid extremely long profiles 43 | sample_size = min(100000, len(test_texts[0])) 44 | sample_text = test_texts[0][:sample_size] 45 | 46 | print(f"Profiling tokenization of a {sample_size} character sample...") 47 | 48 | # Function to profile 49 | def tokenize_sample(): 50 | # Let's tokenize the sample multiple times to get more data 51 | for _ in range(5): 52 | sentences = tokenizer.tokenize(sample_text) 53 | # Force evaluation of generator 54 | _ = len(sentences) # Force evaluation 55 | 56 | # Function to profile specific methods with caching 57 | def profile_cached_methods(): 58 | print("\nProfiling LRU-cached methods specifically...") 59 | 60 | # Create text samples that will trigger the cached methods 61 | sample_tokens = list(tokenizer._tokenize_words(sample_text))[:100] 62 | 63 | # For orthographic heuristic profiling 64 | token_samples = sample_tokens[:20] # Get a few sample tokens 65 | 66 | # Profile orthographic heuristic 67 | profile_ortho = cProfile.Profile() 68 | profile_ortho.enable() 69 | for token in token_samples: 70 | for _ in range(100): # Repeat multiple times 71 | tokenizer._ortho_heuristic(token) 72 | profile_ortho.disable() 73 | 74 | # Print orthographic heuristic stats 75 | s = io.StringIO() 76 | ps = pstats.Stats(profile_ortho, stream=s).sort_stats("cumulative") 77 | ps.print_stats(20) 78 | print("\nOrthographic Heuristic Profile:") 79 | print(s.getvalue()) 80 | 81 | # Get some token types for sentence starter lookup 82 | token_types = [token.type_no_sentperiod for token in token_samples] 83 | 84 | # Profile sentence starter lookup 85 | profile_sent_starter = cProfile.Profile() 86 | profile_sent_starter.enable() 87 | for typ in token_types: 88 | for _ in range(100): # Repeat multiple times 89 | tokenizer._is_sent_starter(typ) 90 | profile_sent_starter.disable() 91 | 92 | # Print sentence starter lookup stats 93 | s = io.StringIO() 94 | ps = pstats.Stats(profile_sent_starter, stream=s).sort_stats("cumulative") 95 | ps.print_stats(20) 96 | print("\nSentence Starter Lookup Profile:") 97 | print(s.getvalue()) 98 | 99 | # Get some candidate abbreviations 100 | abbrev_candidates = ["mr", "dr", "inc", "ltd", "co", "corp", "prof", "jan", "feb", "mar"] 101 | 102 | # Profile abbreviation lookup 103 | profile_abbrev = cProfile.Profile() 104 | profile_abbrev.enable() 105 | for abbr in abbrev_candidates: 106 | for _ in range(50): # Repeat multiple times 107 | # We need to access the base method directly from the tokenizer instance 108 | 
tokenizer._is_abbreviation(abbr) 109 | profile_abbrev.disable() 110 | 111 | # Print abbreviation lookup stats 112 | s = io.StringIO() 113 | ps = pstats.Stats(profile_abbrev, stream=s).sort_stats("cumulative") 114 | ps.print_stats(20) 115 | print("\nAbbreviation Lookup Profile:") 116 | print(s.getvalue()) 117 | 118 | # Run the main profiler 119 | profile = cProfile.Profile() 120 | profile.enable() 121 | tokenize_sample() 122 | profile.disable() 123 | 124 | # Print sorted stats 125 | s = io.StringIO() 126 | ps = pstats.Stats(profile, stream=s).sort_stats("cumulative") 127 | ps.print_stats(30) # Print top 30 items 128 | print(s.getvalue()) 129 | 130 | # Profile specific cached methods 131 | profile_cached_methods() 132 | 133 | # Save profile results to a file for later analysis 134 | output_dir = script_dir / "utils" / "profiles" 135 | output_dir.mkdir(parents=True, exist_ok=True) 136 | profile_path = output_dir / "tokenizer_profile.prof" 137 | ps.dump_stats(profile_path) 138 | print(f"Saved profile results to {profile_path}") 139 | 140 | # Also print stats sorted by time 141 | s = io.StringIO() 142 | ps = pstats.Stats(profile, stream=s).sort_stats("time") 143 | ps.print_stats(30) # Print top 30 items 144 | print("\nStats sorted by time:") 145 | print(s.getvalue()) 146 | 147 | return profile_path 148 | 149 | 150 | def main(): 151 | """Run the profiling.""" 152 | profile_path = profile_tokenization() 153 | 154 | # Print instructions for viewing the profile with snakeviz 155 | if profile_path: 156 | print("\nTo visualize the profile with snakeviz, run:") 157 | print(" pip install snakeviz") 158 | print(f" snakeviz {profile_path}") 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /scripts/utils/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to test the nupunkt tokenizer on custom text. 4 | 5 | This script: 6 | 1. Loads the default model 7 | 2. Tokenizes text provided by the user 8 | 3. Shows the resulting sentences 9 | """ 10 | 11 | import argparse 12 | import sys 13 | from pathlib import Path 14 | from typing import Optional 15 | 16 | # Add the parent directory to the path so we can import nupunkt 17 | script_dir = Path(__file__).parent 18 | root_dir = script_dir.parent 19 | sys.path.append(str(root_dir)) 20 | 21 | # Import nupunkt 22 | from nupunkt.models import load_default_model 23 | 24 | 25 | def get_test_text() -> str: 26 | """Return sample test text if none is provided.""" 27 | return """ 28 | Dr. Smith went to Washington, D.C. He was very excited about the trip. 29 | The company (Ltd.) was founded in 1997. It has grown significantly since then. 30 | This text contains an ellipsis... And this is a new sentence. 31 | Let me give you an example, e.g. this one. Did you understand it? 32 | The meeting is at 3 p.m. Don't be late! 33 | Under 18 U.S.C. 12, this is a legal citation. The next sentence begins here. 34 | The patient presented with abd. pain. CT scan was ordered. 35 | The table shows results for Jan. Feb. and Mar. Each month shows improvement. 36 | Visit the website at www.example.com. There you'll find more information. 37 | She scored 92 vs. 85 in the previous match. Her performance has improved. 38 | The temperature was 32 deg. C. It was quite hot that day. 
39 | """ 40 | 41 | 42 | def tokenize_text(text: str, model_path: Optional[Path] = None) -> None: 43 | """ 44 | Tokenize the given text and print the results. 45 | 46 | Args: 47 | text: The text to tokenize 48 | model_path: Optional path to a custom model 49 | """ 50 | # Load the tokenizer 51 | print("Loading default model...") 52 | tokenizer = load_default_model() 53 | print("Model loaded successfully.") 54 | 55 | # Tokenize the text 56 | print("\n=== Tokenizing Text ===") 57 | print(f"Input text:\n{text}") 58 | 59 | sentences = tokenizer.tokenize(text) 60 | 61 | print("\n=== Tokenization Results ===") 62 | for i, sentence in enumerate(sentences, 1): 63 | print(f"Sentence {i}: {sentence.strip()}") 64 | 65 | print(f"\nFound {len(sentences)} sentences.") 66 | 67 | 68 | def main() -> None: 69 | """Process command-line arguments and tokenize text.""" 70 | parser = argparse.ArgumentParser(description="Test the nupunkt tokenizer on custom text") 71 | parser.add_argument( 72 | "--text", type=str, default=None, help="Text to tokenize (default: use sample text)" 73 | ) 74 | parser.add_argument( 75 | "--file", type=str, default=None, help="Path to a file containing text to tokenize" 76 | ) 77 | parser.add_argument("--model", type=str, default=None, help="Path to a custom model file") 78 | 79 | args = parser.parse_args() 80 | 81 | # Get the text to tokenize 82 | if args.file: 83 | with open(args.file, encoding="utf-8") as f: 84 | text = f.read() 85 | elif args.text: 86 | text = args.text 87 | else: 88 | text = get_test_text() 89 | 90 | # Tokenize the text 91 | tokenize_text(text, args.model) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /sentences.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sentence segmentation functionality. 3 | """ 4 | 5 | from typing import List, Tuple, Optional, TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from charboundary.segmenters.base import TextSegmenter 9 | 10 | 11 | class SentenceSegmenter: 12 | """ 13 | Handles segmenting text into sentences. 14 | """ 15 | 16 | @staticmethod 17 | def segment_to_sentences( 18 | segmenter: "TextSegmenter", 19 | text: str, 20 | streaming: bool = False, 21 | threshold: Optional[float] = None, 22 | ) -> List[str]: 23 | """ 24 | Segment text into a list of sentences. 25 | 26 | Args: 27 | segmenter: The TextSegmenter to use 28 | text (str): Text to segment 29 | streaming (bool, optional): Whether to use streaming mode for memory efficiency. 30 | Defaults to False. 31 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 32 | Values below 0.5 favor recall (fewer false negatives), 33 | values above 0.5 favor precision (fewer false positives). 34 | If None, use the model's default threshold. 35 | Defaults to None. 
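
As test_tokenizer.py above illustrates, basic use of the default model is a two-step affair: load it, then call tokenize. A minimal sketch; the sample sentence is taken from the test cases earlier in this repository, which expect it to split into two sentences with "Dr." and "D.C." treated as abbreviations.

from nupunkt.models import load_default_model

tokenizer = load_default_model()
sentences = tokenizer.tokenize(
    "Dr. Smith went to Washington, D.C. He was very excited about the trip."
)
for i, sentence in enumerate(sentences, 1):
    print(f"Sentence {i}: {sentence.strip()}")
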
36 | 37 | Returns: 38 | List[str]: List of sentences 39 | """ 40 | # Quick return for empty text 41 | if not text: 42 | return [] 43 | 44 | # Use optimized segmentation based on text size 45 | if streaming and len(text) > 10000: 46 | # For large texts, use streaming segmentation 47 | # Note: streaming mode doesn't currently support custom threshold 48 | segmented_parts = list(segmenter.segment_text_streaming(text)) 49 | segmented_text = "".join(segmented_parts) 50 | else: 51 | # For smaller texts, use regular segmentation 52 | segmented_text = segmenter.segment_text(text, threshold=threshold) 53 | 54 | # Fast path: if no sentence tags were added, return the whole text as one sentence 55 | if segmenter.SENTENCE_TAG not in segmented_text: 56 | return [text] if text else [] 57 | 58 | # More efficient string splitting and processing 59 | # Pre-compute tag lengths for performance 60 | para_tag_len = len(segmenter.PARAGRAPH_TAG) 61 | 62 | # Split by sentence tag, but handle paragraph tags properly 63 | sentences = [] 64 | segments = segmented_text.split(segmenter.SENTENCE_TAG) 65 | 66 | # First segment is always before any sentence tag 67 | if segments[0]: 68 | sentences.append(segments[0]) 69 | 70 | # Process remaining segments (each starts after a sentence tag) 71 | for segment in segments[1:]: 72 | # Remove any paragraph tags at the beginning of the segment 73 | if segment.startswith(segmenter.PARAGRAPH_TAG): 74 | segment = segment[para_tag_len:] 75 | 76 | # Remove any paragraph tags in the segment 77 | segment = segment.replace(segmenter.PARAGRAPH_TAG, "") 78 | 79 | if segment: 80 | sentences.append(segment) 81 | 82 | # Post-processing to fix incorrectly segmented quotation marks 83 | # This handles edge cases where the model fails to correctly process quotes 84 | i = 0 85 | while i < len(sentences) - 1: 86 | # Handle case where a sentence ends with a quote and next "sentence" is just a quote 87 | if (sentences[i].endswith('"') or sentences[i].endswith('"')) and sentences[ 88 | i + 1 89 | ].strip() == '"': 90 | # Merge the quote with the following sentence 91 | if i + 2 < len(sentences): 92 | sentences[i + 2] = '" ' + sentences[i + 2] 93 | sentences.pop(i + 1) # Remove the standalone quote 94 | continue 95 | # Handle case where a "sentence" is just a quote that should connect to the next sentence 96 | if sentences[i].strip() == '"' and i + 1 < len(sentences): 97 | # Join with the next sentence 98 | sentences[i + 1] = '" ' + sentences[i + 1] 99 | sentences.pop(i) # Remove the standalone quote 100 | continue 101 | i += 1 102 | 103 | return sentences 104 | 105 | @classmethod 106 | def get_sentence_spans_with_text( 107 | cls, 108 | segmenter: "TextSegmenter", 109 | text: str, 110 | streaming: bool = False, 111 | threshold: Optional[float] = None, 112 | ) -> List[Tuple[str, Tuple[int, int]]]: 113 | """ 114 | Segment text into a list of sentences with their character spans. 115 | 116 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 117 | and end_idx is exclusive (following Python's slicing convention). 118 | The spans are guaranteed to cover the entire input text without gaps. 119 | 120 | Args: 121 | segmenter: The TextSegmenter to use 122 | text (str): Text to segment 123 | streaming (bool, optional): Whether to use streaming mode for memory efficiency. 124 | Defaults to False. 125 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 
126 | Values below 0.5 favor recall (fewer false negatives), 127 | values above 0.5 favor precision (fewer false positives). 128 | If None, use the model's default threshold. 129 | Defaults to None. 130 | 131 | Returns: 132 | List[tuple[str, tuple[int, int]]]: List of tuples containing (sentence, (start_index, end_index)) 133 | """ 134 | from charboundary.segmenters.spans import SpanHandler 135 | 136 | # Quick return for empty text 137 | if not text: 138 | return [] 139 | 140 | # Find boundary positions 141 | boundary_positions = SpanHandler.find_boundary_positions( 142 | segmenter, text, threshold=threshold 143 | ) 144 | 145 | # If no boundaries found, return the whole text as one sentence 146 | if not boundary_positions: 147 | return [(text, (0, len(text)))] 148 | 149 | # Create spans from boundary positions 150 | result = [] 151 | start_idx = 0 152 | 153 | # Build spans from boundaries 154 | for end_idx in boundary_positions: 155 | sentence = text[start_idx:end_idx] 156 | # Include all spans, even if only whitespace 157 | result.append((sentence, (start_idx, end_idx))) 158 | start_idx = end_idx 159 | 160 | # Add final segment if needed (for text after the last boundary) 161 | if start_idx < len(text): 162 | sentence = text[start_idx:] 163 | result.append((sentence, (start_idx, len(text)))) 164 | 165 | return result 166 | 167 | @classmethod 168 | def get_sentence_spans( 169 | cls, 170 | segmenter: "TextSegmenter", 171 | text: str, 172 | streaming: bool = False, 173 | threshold: Optional[float] = None, 174 | ) -> List[Tuple[int, int]]: 175 | """ 176 | Get the character spans for each sentence in the text. 177 | 178 | Each span is a tuple of (start_idx, end_idx) where start_idx is inclusive 179 | and end_idx is exclusive (following Python's slicing convention). 180 | The spans are guaranteed to cover the entire input text without gaps. 181 | 182 | Args: 183 | segmenter: The TextSegmenter to use 184 | text (str): Text to segment 185 | streaming (bool, optional): Whether to use streaming mode for memory efficiency. 186 | Defaults to False. 187 | threshold (float, optional): Probability threshold for classification (0.0-1.0). 188 | Values below 0.5 favor recall (fewer false negatives), 189 | values above 0.5 favor precision (fewer false positives). 190 | If None, use the model's default threshold. 191 | Defaults to None. 192 | 193 | Returns: 194 | List[tuple[int, int]]: List of character spans (start_index, end_index) 195 | """ 196 | segments_with_spans = cls.get_sentence_spans_with_text( 197 | segmenter, text, streaming=streaming, threshold=threshold 198 | ) 199 | return [span for _, span in segments_with_spans] -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alea-institute/nupunkt/29d056aba0f6c9e0f43ee1e36d2638260027af0c/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration for nupunkt tests.""" 2 | 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import pytest 7 | 8 | from nupunkt import ( 9 | PunktParameters, 10 | PunktSentenceTokenizer, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def sample_text() -> str: 16 | """Return a sample text for testing.""" 17 | return """ 18 | This is a sample text. 
It contains multiple sentences, including abbreviations like Dr. Johnson and Mr. Smith. 19 | The U.S.A. is a country in North America. This example has numbers like 3.14, which aren't abbreviations! 20 | Prof. Jones works at the university. She has a Ph.D. in Computer Science. 21 | """ 22 | 23 | 24 | @pytest.fixture 25 | def legal_text() -> str: 26 | """Return a sample legal text for testing.""" 27 | return """ 28 | The Court finds as follows. Pursuant to 28 U.S.C. § 1332, diversity jurisdiction exists. 29 | The plaintiff, Mr. Smith, filed suit against Corp. Inc. on Jan. 15, 2023. Judge Davis presided. 30 | The case was dismissed with prejudice. See Smith v. Corp., 123 F.Supp.2d 456 (N.D. Cal. 2023). 31 | """ 32 | 33 | 34 | @pytest.fixture 35 | def scientific_text() -> str: 36 | """Return a sample scientific text for testing.""" 37 | return """ 38 | Recent studies show promising results. Fig. 3 demonstrates the correlation between variables A and B. 39 | Dr. Williams et al. (2023) found that approx. 75% of samples exhibited the effect. The p-value was 0.01. 40 | The solution was diluted to 0.5 mg/ml. This conc. was found to be optimal for cell growth. 41 | """ 42 | 43 | 44 | @pytest.fixture 45 | def common_abbreviations() -> List[str]: 46 | """Return a list of common abbreviations for testing.""" 47 | return ["dr", "mr", "mrs", "ms", "prof", "etc", "e.g", "i.e", "u.s.a", "ph.d"] 48 | 49 | 50 | @pytest.fixture 51 | def punkt_params() -> PunktParameters: 52 | """Return a basic PunktParameters object for testing.""" 53 | params = PunktParameters() 54 | params.abbrev_types.update(["dr", "mr", "prof", "etc", "e.g", "i.e"]) 55 | return params 56 | 57 | 58 | @pytest.fixture 59 | def test_tokenizer(punkt_params) -> PunktSentenceTokenizer: 60 | """Return a pre-configured tokenizer for testing.""" 61 | return PunktSentenceTokenizer(punkt_params) 62 | 63 | 64 | @pytest.fixture 65 | def data_dir() -> Path: 66 | """Return the path to the test data directory.""" 67 | return Path(__file__).parent / "data" 68 | 69 | 70 | @pytest.fixture 71 | def create_test_data(data_dir) -> None: 72 | """Create some test data files.""" 73 | # Create a small JSONL file for testing 74 | data_dir.mkdir(parents=True, exist_ok=True) 75 | jsonl_path = data_dir / "test_small.jsonl" 76 | 77 | with open(jsonl_path, "w", encoding="utf-8") as f: 78 | f.write('{"text": "This is sentence one. This is sentence two."}\n') 79 | f.write('{"text": "Dr. Smith visited Mr. Jones. They discussed the U.S.A."}\n') 80 | f.write('{"text": "This contains a number 3.14 which is not an abbreviation."}\n') 81 | 82 | # Create a mixed test file 83 | text_path = data_dir / "mixed_text.txt" 84 | with open(text_path, "w", encoding="utf-8") as f: 85 | f.write(""" 86 | This is a paragraph with multiple sentences. It has abbreviations like Dr. and Mr. Smith. 87 | 88 | This is another paragraph. The U.S.A. is mentioned here. Also Prof. Jones with his Ph.D. 89 | 90 | Here are some numbers: 3.14, 2.71, and 1.62 which should not be treated as abbreviations. 
91 | """) 92 | 93 | return None 94 | -------------------------------------------------------------------------------- /tests/test_language_vars.py: -------------------------------------------------------------------------------- 1 | """Unit tests for nupunkt language variables module.""" 2 | 3 | import pytest 4 | 5 | from nupunkt.core.language_vars import PunktLanguageVars 6 | 7 | 8 | def test_punkt_language_vars_basic(): 9 | """Test basic properties of PunktLanguageVars.""" 10 | lang_vars = PunktLanguageVars() 11 | 12 | # Test sentence end characters 13 | assert "." in lang_vars.sent_end_chars 14 | assert "?" in lang_vars.sent_end_chars 15 | assert "!" in lang_vars.sent_end_chars 16 | 17 | # Test internal punctuation 18 | assert "," in lang_vars.internal_punctuation 19 | assert ":" in lang_vars.internal_punctuation 20 | assert ";" in lang_vars.internal_punctuation 21 | 22 | 23 | def test_punkt_language_vars_word_tokenize(): 24 | """Test word tokenization in PunktLanguageVars.""" 25 | lang_vars = PunktLanguageVars() 26 | 27 | # Test simple word tokenization 28 | words = lang_vars.word_tokenize("Hello world") 29 | assert words == ["Hello", "world"] 30 | 31 | # Test with punctuation 32 | words = lang_vars.word_tokenize("Hello, world!") 33 | assert words == ["Hello", ",", "world", "!"] 34 | 35 | # Test with period - accepting current behavior which doesn't split period 36 | words = lang_vars.word_tokenize("End.") 37 | assert words == ["End."] 38 | 39 | 40 | def test_punkt_language_vars_properties(): 41 | """Test properties of PunktLanguageVars.""" 42 | lang_vars = PunktLanguageVars() 43 | 44 | # Test regex properties 45 | assert hasattr(lang_vars, "word_tokenize_pattern") 46 | assert hasattr(lang_vars, "period_context_pattern") 47 | 48 | # Test case detection via first_upper/first_lower 49 | from nupunkt.core.tokens import PunktToken 50 | 51 | upper_token = PunktToken("Word") 52 | assert upper_token.first_upper 53 | assert not upper_token.first_lower 54 | 55 | lower_token = PunktToken("word") 56 | assert not lower_token.first_upper 57 | assert lower_token.first_lower 58 | 59 | 60 | def test_punkt_language_vars_pattern_matching(): 61 | """Test pattern matching in PunktLanguageVars.""" 62 | lang_vars = PunktLanguageVars() 63 | 64 | # Test period context pattern (need to access the property first to initialize it) 65 | pattern = lang_vars.period_context_pattern 66 | assert pattern.search("Mr. Smith") 67 | assert pattern.search("U.S.A. ") 68 | 69 | # Test word tokenize pattern 70 | assert lang_vars.word_tokenize_pattern is not None 71 | assert "Word" in lang_vars.word_tokenize("Word") 72 | assert "123" in lang_vars.word_tokenize("123") 73 | 74 | # In current implementation, "Word." is a single token 75 | tokens = lang_vars.word_tokenize("Word.") 76 | assert "Word." in tokens # Period is attached to the word 77 | 78 | # Test pattern with special chars 79 | tokens = lang_vars.word_tokenize("A. Smith") 80 | assert "A." 
in tokens 81 | assert "Smith" in tokens 82 | 83 | 84 | def test_punkt_language_vars_custom(): 85 | """Test customizing PunktLanguageVars.""" 86 | 87 | class CustomLanguageVars(PunktLanguageVars): 88 | """A custom language vars class with different sentence endings.""" 89 | 90 | sent_end_chars = (".", "?", "!", ";") # Add semicolon as sentence end 91 | 92 | custom_vars = CustomLanguageVars() 93 | 94 | # Check that our custom class has semicolon as sentence end 95 | assert ";" in custom_vars.sent_end_chars 96 | assert ";" not in PunktLanguageVars().sent_end_chars 97 | 98 | 99 | @pytest.mark.benchmark(group="language_vars") 100 | def test_word_tokenize_benchmark(benchmark): 101 | """Benchmark the word_tokenize method.""" 102 | lang_vars = PunktLanguageVars() 103 | 104 | text = """ 105 | This is a benchmark test for word tokenization functionality in PunktLanguageVars. 106 | It contains multiple sentences, with various punctuation marks like commas, periods, 107 | question marks, and exclamation points! Does it handle all of these correctly? 108 | Numbers like 3.14 and abbreviations like Dr. Smith should be handled properly. 109 | U.S.A. is a country. This sentence ends the paragraph. 110 | 111 | This starts a new paragraph. It should be tokenized correctly as well. 112 | """ 113 | 114 | # Run the benchmark 115 | tokens = benchmark(lambda: lang_vars.word_tokenize(text)) 116 | 117 | # Verify we got reasonable results 118 | assert len(tokens) > 0 119 | assert "This" in tokens 120 | assert "benchmark" in tokens 121 | 122 | # Verify specific tokens that include periods 123 | assert any(token.endswith(".") for token in tokens) 124 | assert "U.S.A." in tokens or "U.S.A" in tokens 125 | -------------------------------------------------------------------------------- /tests/test_parameters.py: -------------------------------------------------------------------------------- 1 | """Unit tests for nupunkt parameters module.""" 2 | 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | from nupunkt.core.parameters import PunktParameters 9 | 10 | 11 | def test_punkt_parameters_basic(): 12 | """Test basic functionality of PunktParameters.""" 13 | params = PunktParameters() 14 | 15 | # Should start empty 16 | assert len(params.abbrev_types) == 0 17 | assert len(params.collocations) == 0 18 | assert len(params.sent_starters) == 0 19 | 20 | # Test adding abbreviations 21 | params.abbrev_types.add("dr") 22 | assert "dr" in params.abbrev_types 23 | 24 | # Test adding collocations 25 | params.collocations.add(("new", "york")) 26 | assert ("new", "york") in params.collocations 27 | 28 | # Test adding sentence starters 29 | params.sent_starters.add("however") 30 | assert "however" in params.sent_starters 31 | 32 | # Test the new helper methods 33 | params.add_abbreviation("mr") 34 | assert "mr" in params.abbrev_types 35 | 36 | params.add_sent_starter("furthermore") 37 | assert "furthermore" in params.sent_starters 38 | 39 | # Test updating sets 40 | params.update_abbrev_types({"prof", "dr", "ms"}) 41 | assert "prof" in params.abbrev_types 42 | assert "ms" in params.abbrev_types 43 | 44 | params.update_sent_starters({"additionally", "moreover"}) 45 | assert "additionally" in params.sent_starters 46 | assert "moreover" in params.sent_starters 47 | 48 | 49 | def test_punkt_parameters_ortho_context(): 50 | """Test orthographic context in PunktParameters.""" 51 | params = PunktParameters() 52 | 53 | # Test adding orthographic context 54 | params.add_ortho_context("word", 1) 55 | assert 
params.ortho_context["word"] == 1 56 | 57 | # Test adding to existing context 58 | params.add_ortho_context("word", 2) 59 | assert params.ortho_context["word"] == 3 # 1 | 2 = 3 60 | 61 | 62 | def test_punkt_parameters_save_load(): 63 | """Test saving and loading parameters.""" 64 | with tempfile.TemporaryDirectory() as tmpdir: 65 | # Create parameters 66 | params = PunktParameters() 67 | params.abbrev_types.update(["dr", "mr", "prof"]) 68 | params.collocations.add(("new", "york")) 69 | params.sent_starters.add("however") 70 | params.add_ortho_context("word", 1) 71 | 72 | # Save uncompressed JSON 73 | uncompressed_path = Path(tmpdir) / "params.json" 74 | params.save(uncompressed_path, format_type="json") 75 | 76 | # Verify it's saved as expected 77 | assert uncompressed_path.exists() 78 | 79 | # Load back 80 | loaded_params = PunktParameters.load(uncompressed_path) 81 | 82 | # Verify it's the same 83 | assert "dr" in loaded_params.abbrev_types 84 | assert "mr" in loaded_params.abbrev_types 85 | assert "prof" in loaded_params.abbrev_types 86 | assert ("new", "york") in loaded_params.collocations 87 | assert "however" in loaded_params.sent_starters 88 | assert loaded_params.ortho_context["word"] == 1 89 | 90 | # Save compressed JSON 91 | compressed_path = Path(tmpdir) / "params.json.xz" 92 | params.save(compressed_path, format_type="json_xz") 93 | 94 | # Verify it's saved as expected 95 | assert compressed_path.exists() 96 | 97 | # Load back 98 | loaded_params = PunktParameters.load(compressed_path) 99 | 100 | # Verify it's the same 101 | assert "dr" in loaded_params.abbrev_types 102 | assert "mr" in loaded_params.abbrev_types 103 | assert "prof" in loaded_params.abbrev_types 104 | assert ("new", "york") in loaded_params.collocations 105 | assert "however" in loaded_params.sent_starters 106 | assert loaded_params.ortho_context["word"] == 1 107 | 108 | # Save binary format 109 | binary_path = Path(tmpdir) / "params.bin" 110 | params.save(binary_path, format_type="binary") 111 | 112 | # Verify it's saved as expected 113 | assert binary_path.exists() 114 | 115 | # Load back 116 | loaded_params = PunktParameters.load(binary_path) 117 | 118 | # Verify it's the same 119 | assert "dr" in loaded_params.abbrev_types 120 | assert "mr" in loaded_params.abbrev_types 121 | assert "prof" in loaded_params.abbrev_types 122 | assert ("new", "york") in loaded_params.collocations 123 | assert "however" in loaded_params.sent_starters 124 | assert loaded_params.ortho_context["word"] == 1 125 | 126 | 127 | def test_punkt_parameters_json_methods(): 128 | """Test to_json and from_json methods.""" 129 | # Create parameters 130 | params = PunktParameters() 131 | params.abbrev_types.update(["dr", "mr", "prof"]) 132 | params.collocations.add(("new", "york")) 133 | params.sent_starters.add("however") 134 | params.add_ortho_context("word", 1) 135 | 136 | # Convert to JSON 137 | json_data = params.to_json() 138 | 139 | # Verify it's a dict 140 | assert isinstance(json_data, dict) 141 | assert "abbrev_types" in json_data 142 | assert "collocations" in json_data 143 | assert "sent_starters" in json_data 144 | assert "ortho_context" in json_data 145 | 146 | # Convert back from JSON 147 | new_params = PunktParameters.from_json(json_data) 148 | 149 | # Verify it's the same 150 | assert set(new_params.abbrev_types) == set(params.abbrev_types) 151 | assert set(new_params.collocations) == set(params.collocations) 152 | assert set(new_params.sent_starters) == set(params.sent_starters) 153 | assert new_params.ortho_context == 
params.ortho_context 154 | 155 | 156 | def test_regex_pattern_compilation(): 157 | """Test regex pattern compilation for abbreviations and sentence starters.""" 158 | params = PunktParameters() 159 | 160 | # Add a substantial number of abbreviations 161 | for i in range(60): 162 | params.add_abbreviation(f"abbr{i}") 163 | 164 | # Add a substantial number of sentence starters 165 | for i in range(60): 166 | params.add_sent_starter(f"start{i}") 167 | 168 | # Get the patterns 169 | abbrev_pattern = params.get_abbrev_pattern() 170 | sent_starter_pattern = params.get_sent_starter_pattern() 171 | 172 | # Verify the patterns work as expected 173 | assert abbrev_pattern.match("abbr0") 174 | assert abbrev_pattern.match("abbr59") 175 | assert not abbrev_pattern.match("nonexistent") 176 | 177 | assert sent_starter_pattern.match("start0") 178 | assert sent_starter_pattern.match("start59") 179 | assert not sent_starter_pattern.match("nonexistent") 180 | 181 | # Test case insensitivity 182 | assert abbrev_pattern.match("ABBR10") 183 | assert sent_starter_pattern.match("START10") 184 | 185 | # Test pattern caching 186 | abbrev_pattern_second = params.get_abbrev_pattern() 187 | assert abbrev_pattern is abbrev_pattern_second # Same object, not recompiled 188 | 189 | # Test invalidation on update 190 | params.add_abbreviation("newabbr") 191 | abbrev_pattern_third = params.get_abbrev_pattern() 192 | assert abbrev_pattern is not abbrev_pattern_third # Should be recompiled 193 | assert abbrev_pattern_third.match("newabbr") # Should include the new abbreviation 194 | 195 | 196 | @pytest.mark.benchmark(group="parameters") 197 | def test_parameters_save_benchmark(benchmark): 198 | """Benchmark parameter saving with/without compression.""" 199 | # Create parameters with substantial data 200 | params = PunktParameters() 201 | 202 | # Add a significant number of abbreviations 203 | params.abbrev_types.update([f"abbrev{i}" for i in range(500)]) 204 | 205 | # Add collocations 206 | for i in range(200): 207 | params.collocations.add((f"word{i}", f"word{i + 1}")) 208 | 209 | # Add sentence starters 210 | params.sent_starters.update([f"starter{i}" for i in range(100)]) 211 | 212 | # Add orthographic contexts 213 | for i in range(1000): 214 | params.add_ortho_context(f"word{i}", i % 4) 215 | 216 | def save_compressed(): 217 | with tempfile.NamedTemporaryFile(suffix=".json.xz", delete=True) as tmp: 218 | params.save(tmp.name, format_type="json_xz", compression_level=1) 219 | # Get file size 220 | size = Path(tmp.name).stat().st_size 221 | # Return size to see in benchmark results 222 | return size 223 | 224 | # Run the benchmark 225 | file_size = benchmark(save_compressed) 226 | 227 | # Simple verification that something was saved 228 | assert file_size > 0 229 | 230 | 231 | @pytest.mark.benchmark(group="parameters") 232 | def test_parameters_load_benchmark(benchmark): 233 | """Benchmark parameter loading with/without compression.""" 234 | # Create parameters with substantial data 235 | params = PunktParameters() 236 | 237 | # Add a significant number of abbreviations 238 | params.abbrev_types.update([f"abbrev{i}" for i in range(500)]) 239 | 240 | # Add collocations 241 | for i in range(200): 242 | params.collocations.add((f"word{i}", f"word{i + 1}")) 243 | 244 | # Add sentence starters 245 | params.sent_starters.update([f"starter{i}" for i in range(100)]) 246 | 247 | # Add orthographic contexts 248 | for i in range(1000): 249 | params.add_ortho_context(f"word{i}", i % 4) 250 | 251 | # Save to a temporary file first 252 | 
with tempfile.NamedTemporaryFile(suffix=".json.xz", delete=False) as tmp: 253 | params.save(tmp.name, format_type="json_xz", compression_level=1) 254 | temp_path = tmp.name 255 | 256 | def load_compressed(): 257 | loaded_params = PunktParameters.load(temp_path) 258 | return loaded_params 259 | 260 | # Run the benchmark 261 | loaded_params = benchmark(load_compressed) 262 | 263 | # Cleanup 264 | Path(temp_path).unlink() 265 | 266 | # Simple verification that it loaded correctly 267 | assert len(loaded_params.abbrev_types) == 500 268 | assert len(loaded_params.collocations) == 200 269 | assert len(loaded_params.sent_starters) == 100 270 | assert len(loaded_params.ortho_context) == 1000 271 | -------------------------------------------------------------------------------- /tests/test_tokens.py: -------------------------------------------------------------------------------- 1 | """Unit tests for nupunkt token module.""" 2 | 3 | import pytest 4 | 5 | from nupunkt.core.tokens import PunktToken, create_punkt_token 6 | 7 | 8 | def test_punkt_token_basic(): 9 | """Test basic properties of PunktToken.""" 10 | # Simple token 11 | token = PunktToken("word") 12 | assert token.tok == "word" 13 | assert token.type == "word" 14 | assert not token.period_final 15 | 16 | # Token with period 17 | token = PunktToken("word.") 18 | assert token.tok == "word." 19 | assert token.type == "word." 20 | assert token.period_final 21 | 22 | 23 | def test_punkt_token_attributes(): 24 | """Test PunktToken attributes.""" 25 | # Token with parameters 26 | token = PunktToken("word", parastart=True, linestart=True) 27 | assert token.parastart 28 | assert token.linestart 29 | assert not token.sentbreak 30 | assert not token.abbr 31 | assert not token.ellipsis 32 | 33 | 34 | def test_punkt_token_type_methods(): 35 | """Test PunktToken type methods.""" 36 | # Test type_no_period 37 | token = PunktToken("word.") 38 | assert token.type_no_period == "word" 39 | 40 | # Test type_no_sentperiod 41 | token = PunktToken("word.") 42 | token.sentbreak = True 43 | assert token.type_no_sentperiod == "word" 44 | 45 | token = PunktToken("word.") 46 | token.sentbreak = False 47 | assert token.type_no_sentperiod == "word." 
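# The two type_no_sentperiod checks above hinge on the sentbreak flag; as an
# illustrative restatement (a sketch that only re-expresses those assertions,
# with pytest.mark.parametrize as the sole addition), the same behaviour can be
# written as a single parametrized test:
@pytest.mark.parametrize(
    "raw, sentbreak, expected",
    [
        ("word.", True, "word"),    # sentence-final period is dropped from the type
        ("word.", False, "word."),  # otherwise the period stays part of the type
    ],
)
def test_type_no_sentperiod_parametrized(raw, sentbreak, expected):
    token = PunktToken(raw)
    token.sentbreak = sentbreak
    assert token.type_no_sentperiod == expected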
48 | 49 | 50 | def test_punkt_token_case_properties(): 51 | """Test PunktToken case detection properties.""" 52 | # Test first_upper 53 | token = PunktToken("Word") 54 | assert token.first_upper 55 | assert not token.first_lower 56 | 57 | # Test first_lower 58 | token = PunktToken("word") 59 | assert not token.first_upper 60 | assert token.first_lower 61 | 62 | # Test first_case 63 | token = PunktToken("Word") 64 | assert token.first_case == "upper" 65 | 66 | token = PunktToken("word") 67 | assert token.first_case == "lower" 68 | 69 | token = PunktToken("123") 70 | assert token.first_case == "none" 71 | 72 | 73 | def test_punkt_token_special_types(): 74 | """Test special type detection in PunktToken.""" 75 | # Test is_ellipsis with standard ASCII ellipsis 76 | token = PunktToken("...") 77 | assert token.is_ellipsis 78 | 79 | # Test is_ellipsis with Unicode ellipsis character 80 | token = PunktToken("\u2026") 81 | assert token.is_ellipsis 82 | 83 | # Test is_ellipsis with Unicode ellipsis at end of word 84 | token = PunktToken("word\u2026") 85 | assert token.is_ellipsis 86 | 87 | # Test is_number 88 | token = PunktToken("123") 89 | assert token.is_number 90 | 91 | token = PunktToken("3.14") 92 | assert token.is_number 93 | 94 | # Test is_initial 95 | token = PunktToken("A.") 96 | assert token.is_initial 97 | 98 | # Test is_alpha 99 | token = PunktToken("word") 100 | assert token.is_alpha 101 | 102 | # Test is_non_punct 103 | token = PunktToken("word") 104 | assert token.is_non_punct 105 | 106 | token = PunktToken(".") 107 | assert not token.is_non_punct 108 | 109 | 110 | @pytest.mark.benchmark(group="tokens") 111 | def test_token_creation_benchmark(benchmark): 112 | """Benchmark token creation.""" 113 | 114 | def create_tokens(): 115 | tokens = [] 116 | for i in range(1000): 117 | if i % 3 == 0: 118 | # Regular word 119 | token = PunktToken(f"word{i}") 120 | elif i % 3 == 1: 121 | # Word with period 122 | token = PunktToken(f"abbrev{i}.") 123 | else: 124 | # Mixed case with punctuation 125 | token = PunktToken(f"Mixed{i}!") 126 | tokens.append(token) 127 | return tokens 128 | 129 | # Run the benchmark 130 | tokens = benchmark(create_tokens) 131 | 132 | # Simple verification 133 | assert len(tokens) == 1000 134 | assert tokens[0].tok == "word0" 135 | assert tokens[1].tok == "abbrev1." 136 | assert tokens[2].tok == "Mixed2!" 137 | 138 | 139 | @pytest.mark.benchmark(group="tokens") 140 | def test_token_factory_benchmark(benchmark): 141 | """Benchmark token creation through the factory function.""" 142 | 143 | def create_tokens_with_factory(): 144 | tokens = [] 145 | for i in range(1000): 146 | if i % 3 == 0: 147 | # Regular word 148 | token = create_punkt_token(f"word{i}") 149 | elif i % 3 == 1: 150 | # Word with period 151 | token = create_punkt_token(f"abbrev{i}.") 152 | else: 153 | # Mixed case with punctuation 154 | token = create_punkt_token(f"Mixed{i}!") 155 | tokens.append(token) 156 | return tokens 157 | 158 | # Run the benchmark 159 | tokens = benchmark(create_tokens_with_factory) 160 | 161 | # Simple verification 162 | assert len(tokens) == 1000 163 | assert tokens[0].tok == "word0" 164 | assert tokens[1].tok == "abbrev1." 165 | assert tokens[2].tok == "Mixed2!" 
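# A note on running the benchmark tests in this module: the `benchmark` fixture and the
# @pytest.mark.benchmark(group="tokens") markers come from the pytest-benchmark plugin,
# so these tests presume it is installed. Assuming that, a typical invocation might be:
#
#     pytest tests/test_tokens.py --benchmark-only --benchmark-group-by=group
#
# --benchmark-only skips the plain unit tests, and grouping by "group" keeps the token
# creation, factory, and property-access benchmarks in a single comparison table.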
166 | 167 | 168 | @pytest.mark.benchmark(group="tokens") 169 | def test_token_factory_with_cache_benchmark(benchmark): 170 | """Benchmark token creation through the factory function with cache hits.""" 171 | 172 | # Pre-populate the token cache with some common tokens 173 | common_tokens = [ 174 | "the", 175 | "of", 176 | "and", 177 | "a", 178 | "to", 179 | "in", 180 | "is", 181 | "that", 182 | "it", 183 | "was", 184 | "for", 185 | "on", 186 | "are", 187 | "as", 188 | "with", 189 | "his", 190 | "they", 191 | "at", 192 | "be", 193 | "this", 194 | "from", 195 | "have", 196 | "or", 197 | "by", 198 | "one", 199 | "had", 200 | "not", 201 | "but", 202 | "what", 203 | "all", 204 | ] 205 | for tok in common_tokens: 206 | create_punkt_token(tok) 207 | 208 | def create_tokens_with_cache(): 209 | tokens = [] 210 | # First add common tokens that should hit the cache 211 | for tok in common_tokens: 212 | tokens.append(create_punkt_token(tok)) 213 | 214 | # Then add some new tokens 215 | for i in range(500): 216 | if i % 3 == 0: 217 | # Regular word 218 | token = create_punkt_token(f"word{i}") 219 | elif i % 3 == 1: 220 | # Word with period 221 | token = create_punkt_token(f"abbrev{i}.") 222 | else: 223 | # Mixed case with punctuation 224 | token = create_punkt_token(f"Mixed{i}!") 225 | tokens.append(token) 226 | return tokens 227 | 228 | # Run the benchmark 229 | tokens = benchmark(create_tokens_with_cache) 230 | 231 | # Simple verification 232 | assert len(tokens) == 30 + 500 233 | assert tokens[0].tok == "the" 234 | 235 | 236 | @pytest.mark.benchmark(group="tokens") 237 | def test_token_property_access_benchmark(benchmark): 238 | """Benchmark access to token properties.""" 239 | # Create a variety of tokens first 240 | tokens = [] 241 | for i in range(1000): 242 | if i % 4 == 0: 243 | tokens.append(PunktToken(f"Word{i}")) 244 | elif i % 4 == 1: 245 | tokens.append(PunktToken(f"abbrev{i}.")) 246 | elif i % 4 == 2: 247 | tokens.append(PunktToken(f"{i}.{i}")) 248 | else: 249 | tokens.append(PunktToken(f"...{i}")) 250 | 251 | def access_properties(): 252 | results = [] 253 | for token in tokens: 254 | # Access various properties 255 | props = ( 256 | token.type, 257 | token.period_final, 258 | token.is_ellipsis, 259 | token.is_number, 260 | token.is_initial, 261 | token.is_alpha, 262 | token.is_non_punct, 263 | token.first_case, 264 | ) 265 | results.append(props) 266 | return results 267 | 268 | # Run the benchmark 269 | results = benchmark(access_properties) 270 | 271 | # Simple verification 272 | assert len(results) == 1000 273 | 274 | 275 | @pytest.mark.benchmark(group="tokens") 276 | def test_lazy_property_access_benchmark(benchmark): 277 | """Benchmark access to lazily evaluated properties.""" 278 | # Create tokens using the factory function 279 | tokens = [] 280 | for i in range(1000): 281 | if i % 4 == 0: 282 | tokens.append(create_punkt_token(f"Word{i}")) 283 | elif i % 4 == 1: 284 | tokens.append(create_punkt_token(f"abbrev{i}.")) 285 | elif i % 4 == 2: 286 | tokens.append(create_punkt_token(f"{i}.{i}")) 287 | else: 288 | tokens.append(create_punkt_token(f"...{i}")) 289 | 290 | def access_lazy_properties(): 291 | results = [] 292 | for token in tokens: 293 | # Access lazily evaluated properties 294 | props = ( 295 | token.type_no_period, 296 | token.type_no_sentperiod, 297 | token.is_ellipsis, 298 | token.is_number, 299 | token.is_initial, 300 | token.is_alpha, 301 | token.is_non_punct, 302 | ) 303 | results.append(props) 304 | return results 305 | 306 | # Run the benchmark 307 | results = 
benchmark(access_lazy_properties) 308 | 309 | # Simple verification 310 | assert len(results) == 1000 311 | 312 | 313 | def test_slots_vs_dict(): 314 | """Test memory efficiency of slots vs dict.""" 315 | # Create a token using the slots-based class 316 | token = PunktToken("test") 317 | 318 | # Verify it uses slots 319 | assert not hasattr(token, "__dict__") 320 | 321 | # Verify all attributes are accessible 322 | assert token.tok == "test" 323 | assert token.parastart is False 324 | assert token.linestart is False 325 | assert token.sentbreak is False 326 | assert token.abbr is False 327 | assert token.ellipsis is False 328 | assert token.period_final is False 329 | assert token.type == "test" 330 | 331 | # Test lazy property initialization 332 | assert token.is_ellipsis is False 333 | # After access, the property should be calculated and cached 334 | assert token._is_ellipsis is False 335 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Basic tests for nupunkt utility functions.""" 2 | 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | from nupunkt.utils.compression import load_compressed_json, save_compressed_json 9 | from nupunkt.utils.iteration import pair_iter 10 | from nupunkt.utils.statistics import collocation_log_likelihood, dunning_log_likelihood 11 | 12 | 13 | def test_pair_iter(): 14 | """Test the pair_iter utility function.""" 15 | # Empty list should yield nothing 16 | assert list(pair_iter([])) == [] 17 | 18 | # Single item list should yield a single pair with None as second element 19 | result = list(pair_iter([1])) 20 | assert len(result) == 1 21 | assert result[0] == (1, None) 22 | 23 | # Two item list 24 | result = list(pair_iter([1, 2])) 25 | assert len(result) == 2 26 | assert result[0] == (1, 2) 27 | assert result[1] == (2, None) 28 | 29 | # Multiple item list 30 | result = list(pair_iter([1, 2, 3, 4])) 31 | assert len(result) == 4 32 | assert result[0] == (1, 2) 33 | assert result[1] == (2, 3) 34 | assert result[2] == (3, 4) 35 | assert result[3] == (4, None) 36 | 37 | 38 | def test_dunning_log_likelihood(): 39 | """Test the dunning_log_likelihood function.""" 40 | # Test with simple values 41 | ll = dunning_log_likelihood(100, 1000, 50, 10000) 42 | assert isinstance(ll, float) 43 | # The function returns negative values by design for Punkt algorithm 44 | 45 | # Higher count_ab should generally result in smaller negative values 46 | ll1 = dunning_log_likelihood(100, 1000, 10, 10000) 47 | ll2 = dunning_log_likelihood(100, 1000, 50, 10000) 48 | assert ll1 < ll2 49 | 50 | # Test with edge cases 51 | ll = dunning_log_likelihood(0, 0, 0, 1) 52 | assert isinstance(ll, float) 53 | assert ll == 0.0 # Edge case handling 54 | 55 | # Test with (1,1,1,1) corner case 56 | ll = dunning_log_likelihood(1, 1, 1, 1) 57 | assert isinstance(ll, float) 58 | # This case can return negative values by design 59 | 60 | 61 | def test_collocation_log_likelihood(): 62 | """Test the collocation_log_likelihood function.""" 63 | # Test with simple values 64 | ll = collocation_log_likelihood(100, 200, 50, 10000) 65 | assert isinstance(ll, float) 66 | assert ll > 0 67 | 68 | # Higher count_ab should generally result in higher likelihood 69 | ll1 = collocation_log_likelihood(100, 200, 10, 10000) 70 | ll2 = collocation_log_likelihood(100, 200, 50, 10000) 71 | assert ll2 > ll1 72 | 73 | # Perfect correlation should have high 
likelihood 74 | ll = collocation_log_likelihood(100, 100, 100, 10000) 75 | assert ll > 0 76 | 77 | # Test edge cases 78 | ll = collocation_log_likelihood(0, 0, 0, 1) 79 | assert isinstance(ll, float) 80 | 81 | ll = collocation_log_likelihood(1, 1, 1, 1) 82 | assert isinstance(ll, float) 83 | 84 | 85 | def test_compression_functions_basic(): 86 | """Test basic functionality of compression utility functions.""" 87 | with tempfile.TemporaryDirectory() as tmpdir: 88 | test_data = {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}} 89 | 90 | # Test without compression 91 | uncompressed_path = Path(tmpdir) / "test_uncompressed.json" 92 | save_compressed_json(test_data, uncompressed_path, use_compression=False) 93 | 94 | # Verify it's a standard JSON file 95 | assert uncompressed_path.exists() 96 | assert uncompressed_path.suffix == ".json" 97 | 98 | # Load it back and verify content 99 | loaded_data = load_compressed_json(uncompressed_path) 100 | assert loaded_data == test_data 101 | 102 | # Test with compression 103 | compressed_path = Path(tmpdir) / "test_compressed.json.xz" 104 | save_compressed_json(test_data, compressed_path, use_compression=True) 105 | 106 | # Verify it's compressed 107 | assert compressed_path.exists() 108 | assert compressed_path.suffix == ".xz" 109 | 110 | # Load it back and verify content 111 | loaded_data = load_compressed_json(compressed_path) 112 | assert loaded_data == test_data 113 | 114 | 115 | def test_compression_automatic_extension(): 116 | """Test automatic extension handling in compression functions.""" 117 | with tempfile.TemporaryDirectory() as tmpdir: 118 | test_data = {"key1": "value1", "key2": [1, 2, 3]} 119 | 120 | # Test with base path 121 | base_path = Path(tmpdir) / "test_file" 122 | 123 | # Save with compression 124 | save_compressed_json(test_data, base_path, use_compression=True) 125 | expected_path = Path(f"{base_path}.json.xz") 126 | assert expected_path.exists() 127 | 128 | # Load back with automatic detection - use the expected path directly 129 | loaded_data = load_compressed_json(expected_path) 130 | assert loaded_data == test_data 131 | 132 | # Test with .json extension but requesting compression 133 | json_path = Path(tmpdir) / "test_file.json" 134 | save_compressed_json(test_data, json_path, use_compression=True) 135 | expected_path = Path(f"{json_path}.xz") 136 | assert expected_path.exists() 137 | 138 | # Load back with automatic detection - use the expected path directly 139 | loaded_data = load_compressed_json(expected_path) 140 | assert loaded_data == test_data 141 | 142 | 143 | def test_compression_level_verification(): 144 | """Test that different compression levels work correctly.""" 145 | with tempfile.TemporaryDirectory() as tmpdir: 146 | # Create a larger test dataset 147 | test_data = {f"key_{i}": f"value_{i}" * 100 for i in range(100)} 148 | 149 | # Compress with level 1 (fast) 150 | fast_path = Path(tmpdir) / "fast_compressed.json.xz" 151 | save_compressed_json(test_data, fast_path, level=1) 152 | 153 | # Compress with level 9 (best) 154 | best_path = Path(tmpdir) / "best_compressed.json.xz" 155 | save_compressed_json(test_data, best_path, level=9) 156 | 157 | # Get file sizes 158 | fast_size = fast_path.stat().st_size 159 | best_size = best_path.stat().st_size 160 | 161 | # For very small test data, compression level might not make a big difference 162 | # or might even have reverse effect due to compression metadata overhead 163 | # Just verify that both files are created and have reasonable size 164 | assert 
fast_size > 0 165 | assert best_size > 0 166 | 167 | # Both should load correctly 168 | assert load_compressed_json(fast_path) == test_data 169 | assert load_compressed_json(best_path) == test_data 170 | 171 | 172 | @pytest.mark.benchmark(group="compression") 173 | def test_compression_benchmark(benchmark): 174 | """Benchmark compression functions.""" 175 | # Create test data - mix of strings, numbers, and nested structures 176 | test_data = { 177 | "strings": [f"value_{i}" * 20 for i in range(50)], 178 | "numbers": [i * 3.14159 for i in range(100)], 179 | "nested": [{"id": i, "name": f"item_{i}" * 5, "active": i % 2 == 0} for i in range(50)], 180 | } 181 | 182 | # Prepare data for compression benchmarking 183 | 184 | def compress_func(): 185 | with tempfile.NamedTemporaryFile(suffix=".json.xz", delete=True) as tmp: 186 | # Use fastest compression level for benchmarking 187 | save_compressed_json(test_data, tmp.name, level=1) 188 | # Read back to ensure full cycle 189 | return load_compressed_json(tmp.name) 190 | 191 | # Run the benchmark 192 | result = benchmark(compress_func) 193 | 194 | # Simple assertions to make sure it worked 195 | assert result == test_data 196 | 197 | 198 | @pytest.mark.benchmark(group="compression") 199 | def test_no_compression_benchmark(benchmark): 200 | """Benchmark without compression for comparison.""" 201 | # Use the same test data as in the compression benchmark 202 | test_data = { 203 | "strings": [f"value_{i}" * 20 for i in range(50)], 204 | "numbers": [i * 3.14159 for i in range(100)], 205 | "nested": [{"id": i, "name": f"item_{i}" * 5, "active": i % 2 == 0} for i in range(50)], 206 | } 207 | 208 | def no_compress_func(): 209 | with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as tmp: 210 | save_compressed_json(test_data, tmp.name, use_compression=False) 211 | return load_compressed_json(tmp.name) 212 | 213 | # Run the benchmark 214 | result = benchmark(no_compress_func) 215 | 216 | # Simple assertions to make sure it worked 217 | assert result == test_data 218 | --------------------------------------------------------------------------------
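As a closing illustration of the compression helpers exercised in tests/test_utils.py, the short sketch below is based only on the call signatures that appear in those tests (the params.json.xz filename is just a placeholder); it writes a small dictionary as xz-compressed JSON and reads it back unchanged:

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from nupunkt.utils.compression import load_compressed_json, save_compressed_json

    data = {"abbrev_types": ["dr", "mr"], "counts": [1, 2, 3]}

    with TemporaryDirectory() as tmpdir:
        # level=1 mirrors the fast-compression setting used in the benchmarks above
        target = Path(tmpdir) / "params.json.xz"
        save_compressed_json(data, target, level=1)
        assert load_compressed_json(target) == data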