├── .github └── workflows │ ├── pylint.yml │ └── run_tests.yml ├── .gitignore ├── LICENCE.txt ├── MANIFEST.in ├── Makefile ├── README.md ├── linting_config └── pylint-configuration.pylintrc ├── punctfix ├── __init__.py ├── inference.py ├── models.py └── streaming.py ├── requirements.txt ├── scripts └── test_timing.py ├── setup.py └── tests ├── __init__.py └── test_punctuation.py /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "punctfix/**" 9 | pull_request: 10 | branches: 11 | - main 12 | paths: 13 | - "punctfix/**" 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: ["3.8"] 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install pylint 31 | - name: Analysing the code with pylint 32 | run: | 33 | pylint --rcfile ./linting_config/pylint-configuration.pylintrc $(find ./punctfix/ -name "*.py" | xargs) 34 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: InferenceTests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "punctfix/**" 9 | pull_request: 10 | branches: 11 | - main 12 | paths: 13 | - "punctfix/**" 14 | 15 | jobs: 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: ["3.8"] 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | pip install wheel 30 | pip install -r requirements.txt 31 | - name: Running all tests 32 | run: | 33 | python -m unittest -v tests/test_*.py 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | test_model/ 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | 7 | .eggs 8 | 9 | # Distribution / packaging 10 | bin/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | **.pyo 25 | **.pyc 26 | -------------------------------------------------------------------------------- /LICENCE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 | 
196 | http://www.apache.org/licenses/LICENSE-2.0
197 | 
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENCE.txt
3 | include requirements.txt
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test
2 | 
3 | test:
4 | 	python -m unittest -v tests/test_*.py
5 | 
6 | pylint:
7 | 	pylint --rcfile=./linting_config/pylint-configuration.pylintrc $(shell find ./punctfix/ -name "*.py" | xargs)
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Punctuation restoration
2 | Adds punctuation and capitalization to a given text without punctuation.
3 | 
4 | Works on Danish, German and English.
5 | 
6 | Models are hosted on Hugging Face! ❤️ 🤗
7 | 
8 | ## Status with Python 3.8
9 | ![example workflow](https://github.com/danspeech/punctfix/actions/workflows/run_tests.yml/badge.svg)
10 | ![example workflow](https://github.com/danspeech/punctfix/actions/workflows/pylint.yml/badge.svg)
11 | 
12 | ## Installation
13 | ```
14 | pip install punctfix
15 | ```
16 | 
17 | ## Usage
18 | It's quite simple to use!
19 | 
20 | ```python
21 | >>> from punctfix import PunctFixer
22 | >>> fixer = PunctFixer(language="da")
23 | 
24 | >>> example_text = "mit navn det er rasmus og jeg kommer fra firmaet alvenir det er mig som har trænet denne lækre model"
25 | >>> print(fixer.punctuate(example_text))
26 | 'Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. Det er mig som har trænet denne lækre model.'
27 | 
28 | >>> example_text = "en dag bliver vi sku glade for, at vi nu kan sætte punktummer og kommaer i en sætning det fungerer da meget godt ikke"
29 | >>> print(fixer.punctuate(example_text))
30 | 'En dag bliver vi sku glade for, at vi nu kan sætte punktummer og kommaer i en sætning. Det fungerer da meget godt, ikke?'
31 | ```
32 | 
33 | Note that, by default, the input text will be normalized. See the next section for more details.
34 | 
35 | ## Parameters for PunctFixer
36 | * Pass `device="cuda"` or `device="cpu"` to indicate where to run inference. Default is `device="cpu"`.
37 | * To handle long sequences, we use a chunk size and an overlap. These can be modified. For higher speed but
38 | lower accuracy, use a chunk size of 150-200 and very little overlap, i.e. 5-10. These parameters are set with
39 | default values `word_chunk_size=100`, `word_overlap=70`, which makes it run a bit slow. The default parameters
40 | will be updated when we have some results on variations; see the example below.
41 | * Supported languages are "en" for English, "da" for Danish and "de" for German. Default is `language="da"`.
42 | * Note that the fixer has been trained on normalized text (lowercase letters and numbers) and will by default normalize input text. You can instantiate the model with `skip_normalization=True` to disable this, but it might yield errors on some input text.
43 | * To raise warnings every time the input is normalized, set `warn_on_normalization=True`.
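
For example, a fixer tuned for higher speed at some cost in accuracy could be configured as follows (illustrative values within the ranges above; the expected output for this sample is taken from the English test suite):

```python
>>> from punctfix import PunctFixer
>>> fast_fixer = PunctFixer(language="en", device="cuda",
...                         word_chunk_size=150, word_overlap=10, batch_size=32)
>>> print(fast_fixer.punctuate("hello i come from denmark and i am very good at english"))
'Hello! I come from Denmark and I am very good at English.'
```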
44 | 
45 | ## Contribute
46 | If you encounter issues, feel free to open an issue in the repo and we will fix it. Even better, open an issue and
47 | a PR that fixes the issue! ;-)
48 | 
49 | Happy punctuating!
50 | 
--------------------------------------------------------------------------------
/linting_config/pylint-configuration.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | 
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loaded into the active Python interpreter and may
5 | # run arbitrary code.
6 | extension-pkg-whitelist=
7 | 
8 | # Add files or directories to the blacklist. They should be base names, not
9 | # paths.
10 | ignore=CVS
11 | 
12 | # Add files or directories matching the regex patterns to the blacklist. The
13 | # regex matches against base names, not paths.
14 | ignore-patterns=
15 | 
16 | # Python code to execute, usually for sys.path manipulation such as
17 | # pygtk.require().
18 | #init-hook=
19 | 
20 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
21 | # number of processors available to use.
22 | jobs=0
23 | 
24 | # Control the amount of potential inferred values when inferring a single
25 | # object. This can help the performance when dealing with large functions or
26 | # complex, nested conditions.
27 | limit-inference-results=120
28 | 
29 | # List of plugins (as comma separated values of python module names) to load,
30 | # usually to register additional checkers.
31 | load-plugins=
32 | 
33 | # Pickle collected data for later comparisons.
34 | persistent=yes
35 | 
36 | # Specify a configuration file.
37 | #rcfile=
38 | 
39 | # When enabled, pylint would attempt to guess common misconfiguration and emit
40 | # user-friendly hints instead of false-positive error messages.
41 | suggestion-mode=yes
42 | 
43 | # Allow loading of arbitrary C extensions. Extensions are imported into the
44 | # active Python interpreter and may run arbitrary code.
45 | unsafe-load-any-extension=no
46 | 
47 | 
48 | [MESSAGES CONTROL]
49 | 
50 | # Only show warnings with the listed confidence levels. Leave empty to show
51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
52 | confidence=
53 | 
54 | # Disable the message, report, category or checker with the given id(s). You
55 | # can either give multiple identifiers separated by comma (,) or put this
56 | # option multiple times (only on the command line, not in the configuration
57 | # file where it should appear only once). You can also use "--disable=all" to
58 | # disable everything first and then reenable specific checks. For example, if
59 | # you want to run only the similarities checker, you can use "--disable=all
60 | # --enable=similarities". If you want to run only the classes checker, but have
61 | # no Warning level messages displayed, use "--disable=all --enable=classes
62 | # --disable=W".
63 | disable=missing-module-docstring,import-error
64 | 
65 | # Enable the message, report, category or checker with the given id(s). You can
66 | # either give multiple identifiers separated by comma (,) or put this option
67 | # multiple times (only on the command line, not in the configuration file where
68 | # it should appear only once).
See also the "--disable" option for examples. 69 | enable=c-extension-no-member 70 | 71 | 72 | [REPORTS] 73 | 74 | # Python expression which should return a note less than 10 (10 is the highest 75 | # note). You have access to the variables errors warning, statement which 76 | # respectively contain the number of errors / warnings messages and the total 77 | # number of statements analyzed. This is used by the global evaluation report 78 | # (RP0004). 79 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 80 | 81 | # Template used to display messages. This is a python new-style format string 82 | # used to format the message information. See doc for all details. 83 | #msg-template= 84 | 85 | # Set the output format. Available formats are text, parseable, colorized, json 86 | # and msvs (visual studio). You can also give a reporter class, e.g. 87 | # mypackage.mymodule.MyReporterClass. 88 | output-format=text 89 | 90 | # Tells whether to display a full report or only the messages. 91 | reports=no 92 | 93 | # Activate the evaluation score. 94 | score=yes 95 | 96 | 97 | [REFACTORING] 98 | 99 | # Maximum number of nested blocks for function / method body 100 | max-nested-blocks=5 101 | 102 | # Complete name of functions that never returns. When checking for 103 | # inconsistent-return-statements if a never returning function is called then 104 | # it will be considered as an explicit return statement and no message will be 105 | # printed. 106 | never-returning-functions=sys.exit 107 | 108 | 109 | [VARIABLES] 110 | 111 | # List of additional names supposed to be defined in builtins. Remember that 112 | # you should avoid defining new builtins when possible. 113 | additional-builtins= 114 | 115 | # Tells whether unused global variables should be treated as a violation. 116 | allow-global-unused-variables=yes 117 | 118 | # List of strings which can identify a callback function by name. A callback 119 | # name must start or end with one of those strings. 120 | callbacks=cb_, 121 | _cb 122 | 123 | # A regular expression matching the name of dummy variables (i.e. expected to 124 | # not be used). 125 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 126 | 127 | # Argument names that match this expression will be ignored. Default to name 128 | # with leading underscore. 129 | ignored-argument-names=_.*|^ignored_|^unused_ 130 | 131 | # Tells whether we should check for unused import in __init__ files. 132 | init-import=no 133 | 134 | # List of qualified module names which can have objects that can redefine 135 | # builtins. 136 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 137 | 138 | 139 | [BASIC] 140 | 141 | # Naming style matching correct argument names. 142 | argument-naming-style=snake_case 143 | 144 | # Regular expression matching correct argument names. Overrides argument- 145 | # naming-style. 146 | #argument-rgx= 147 | 148 | # Naming style matching correct attribute names. 149 | attr-naming-style=snake_case 150 | 151 | # Regular expression matching correct attribute names. Overrides attr-naming- 152 | # style. 153 | #attr-rgx= 154 | 155 | # Bad variable names which should always be refused, separated by a comma. 156 | bad-names=foo, 157 | bar, 158 | baz, 159 | toto, 160 | tutu, 161 | tata 162 | 163 | # Naming style matching correct class attribute names. 164 | class-attribute-naming-style=any 165 | 166 | # Regular expression matching correct class attribute names. 
Overrides class- 167 | # attribute-naming-style. 168 | #class-attribute-rgx= 169 | 170 | # Naming style matching correct class names. 171 | class-naming-style=PascalCase 172 | 173 | # Regular expression matching correct class names. Overrides class-naming- 174 | # style. 175 | #class-rgx= 176 | 177 | # Naming style matching correct constant names. 178 | const-naming-style=UPPER_CASE 179 | 180 | # Regular expression matching correct constant names. Overrides const-naming- 181 | # style. 182 | #const-rgx= 183 | 184 | # Minimum line length for functions/classes that require docstrings, shorter 185 | # ones are exempt. 186 | docstring-min-length=-1 187 | 188 | # Naming style matching correct function names. 189 | function-naming-style=snake_case 190 | 191 | # Regular expression matching correct function names. Overrides function- 192 | # naming-style. 193 | #function-rgx= 194 | 195 | # Good variable names which should always be accepted, separated by a comma. 196 | good-names=i, 197 | j, 198 | k, 199 | ex, 200 | Run, 201 | _ 202 | 203 | # Include a hint for the correct naming format with invalid-name. 204 | include-naming-hint=yes 205 | 206 | # Naming style matching correct inline iteration names. 207 | inlinevar-naming-style=any 208 | 209 | # Regular expression matching correct inline iteration names. Overrides 210 | # inlinevar-naming-style. 211 | #inlinevar-rgx= 212 | 213 | # Naming style matching correct method names. 214 | method-naming-style=snake_case 215 | 216 | # Regular expression matching correct method names. Overrides method-naming- 217 | # style. 218 | # method-rgx=snake_case|(setUp)|(tearDown)|(asyncSetUp)|(setUpClass)|(tearDownClass) 219 | 220 | # Naming style matching correct module names. 221 | module-naming-style=snake_case 222 | 223 | # Regular expression matching correct module names. Overrides module-naming- 224 | # style. 225 | #module-rgx= 226 | 227 | # Colon-delimited sets of names that determine each other's naming style when 228 | # the name regexes allow several styles. 229 | name-group= 230 | 231 | # Regular expression which should only match function or class names that do 232 | # not require a docstring. 233 | no-docstring-rgx=^_ 234 | 235 | # List of decorators that produce properties, such as abc.abstractproperty. Add 236 | # to this list to register other decorators that produce valid properties. 237 | # These decorators are taken in consideration only for invalid-name. 238 | property-classes=abc.abstractproperty 239 | 240 | # Naming style matching correct variable names. 241 | variable-naming-style=snake_case 242 | 243 | # Regular expression matching correct variable names. Overrides variable- 244 | # naming-style. 245 | #variable-rgx= 246 | 247 | 248 | [SPELLING] 249 | 250 | # Limits count of emitted suggestions for spelling mistakes. 251 | max-spelling-suggestions=4 252 | 253 | # Spelling dictionary name. Available dictionaries: none. To make it working 254 | # install python-enchant package.. 255 | spelling-dict= 256 | 257 | # List of comma separated words that should not be checked. 258 | spelling-ignore-words= 259 | 260 | # A path to a file that contains private dictionary; one word per line. 261 | spelling-private-dict-file= 262 | 263 | # Tells whether to store unknown words to indicated private dictionary in 264 | # --spelling-private-dict-file option instead of raising a message. 265 | spelling-store-unknown-words=no 266 | 267 | 268 | [TYPECHECK] 269 | 270 | # List of decorators that produce context managers, such as 271 | # contextlib.contextmanager. 
Add to this list to register other decorators that
272 | # produce valid context managers.
273 | contextmanager-decorators=contextlib.contextmanager
274 | 
275 | # List of members which are set dynamically and missed by pylint inference
276 | # system, and so shouldn't trigger E1101 when accessed. Python regular
277 | # expressions are accepted.
278 | generated-members=
279 | 
280 | # Tells whether missing members accessed in mixin class should be ignored. A
281 | # mixin class is detected if its name ends with "mixin" (case insensitive).
282 | ignore-mixin-members=yes
283 | 
284 | # Tells whether to warn about missing members when the owner of the attribute
285 | # is inferred to be None.
286 | ignore-none=yes
287 | 
288 | # This flag controls whether pylint should warn about no-member and similar
289 | # checks whenever an opaque object is returned when inferring. The inference
290 | # can return multiple potential results while evaluating a Python object, but
291 | # some branches might not be evaluated, which results in partial inference. In
292 | # that case, it might be useful to still emit no-member and other checks for
293 | # the rest of the inferred objects.
294 | ignore-on-opaque-inference=yes
295 | 
296 | # List of class names for which member attributes should not be checked (useful
297 | # for classes with dynamically set attributes). This supports the use of
298 | # qualified names.
299 | ignored-classes=optparse.Values,thread._local,_thread._local
300 | 
301 | # List of module names for which member attributes should not be checked
302 | # (useful for modules/projects where namespaces are manipulated during runtime
303 | # and thus existing member attributes cannot be deduced by static analysis). It
304 | # supports qualified module names, as well as Unix pattern matching.
305 | ignored-modules=
306 | 
307 | # Show a hint with possible names when a member name was not found. The aspect
308 | # of finding the hint is based on edit distance.
309 | missing-member-hint=yes
310 | 
311 | # The minimum edit distance a name should have in order to be considered a
312 | # similar match for a missing member name.
313 | missing-member-hint-distance=1
314 | 
315 | # The total number of similar names that should be taken in consideration when
316 | # showing a hint for a missing member.
317 | missing-member-max-choices=1
318 | 
319 | 
320 | [FORMAT]
321 | 
322 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
323 | expected-line-ending-format=
324 | 
325 | # Regexp for a line that is allowed to be longer than the limit.
326 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
327 | 
328 | # Number of spaces of indent required inside a hanging or continued line.
329 | indent-after-paren=4
330 | 
331 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
332 | # tab).
333 | indent-string='    '
334 | 
335 | # Maximum number of characters on a single line.
336 | max-line-length=120
337 | 
338 | # Maximum number of lines in a module.
339 | max-module-lines=1000
340 | 
341 | # Allow the body of a class to be on the same line as the declaration if body
342 | # contains single statement.
343 | single-line-class-stmt=no
344 | 
345 | # Allow the body of an if to be on the same line as the test if there is no
346 | # else.
347 | single-line-if-stmt=no
348 | 
349 | 
350 | [LOGGING]
351 | 
352 | # Format style used to check logging format string. `old` means using %
353 | # formatting, while `new` is for `{}` formatting.
354 | logging-format-style=old 355 | 356 | # Logging modules to check that the string format arguments are in logging 357 | # function parameter format. 358 | logging-modules=logging 359 | 360 | 361 | [MISCELLANEOUS] 362 | 363 | # List of note tags to take in consideration, separated by a comma. 364 | notes=FIXME, 365 | XXX, 366 | TODO 367 | 368 | 369 | [SIMILARITIES] 370 | 371 | # Ignore comments when computing similarities. 372 | ignore-comments=yes 373 | 374 | # Ignore docstrings when computing similarities. 375 | ignore-docstrings=yes 376 | 377 | # Ignore imports when computing similarities. 378 | ignore-imports=no 379 | 380 | # Minimum lines number of a similarity. 381 | min-similarity-lines=4 382 | 383 | 384 | [IMPORTS] 385 | 386 | # Allow wildcard imports from modules that define __all__. 387 | allow-wildcard-with-all=no 388 | 389 | # Analyse import fallback blocks. This can be used to support both Python 2 and 390 | # 3 compatible code, which means that the block might have code that exists 391 | # only in one or another interpreter, leading to false positives when analysed. 392 | analyse-fallback-blocks=no 393 | 394 | # Deprecated modules which should not be used, separated by a comma. 395 | deprecated-modules=optparse,tkinter.tix 396 | 397 | # Create a graph of external dependencies in the given file (report RP0402 must 398 | # not be disabled). 399 | ext-import-graph= 400 | 401 | # Create a graph of every (i.e. internal and external) dependencies in the 402 | # given file (report RP0402 must not be disabled). 403 | import-graph= 404 | 405 | # Create a graph of internal dependencies in the given file (report RP0402 must 406 | # not be disabled). 407 | int-import-graph= 408 | 409 | # Force import order to recognize a module as part of the standard 410 | # compatibility libraries. 411 | known-standard-library= 412 | 413 | # Force import order to recognize a module as part of a third party library. 414 | known-third-party=enchant 415 | 416 | 417 | [DESIGN] 418 | 419 | # Maximum number of arguments for function / method. 420 | max-args=10 421 | 422 | # Maximum number of attributes for a class (see R0902). 423 | max-attributes=10 424 | 425 | # Maximum number of boolean expressions in an if statement. 426 | max-bool-expr=5 427 | 428 | # Maximum number of branch for function / method body. 429 | max-branches=12 430 | 431 | # Maximum number of locals for function / method body. 432 | max-locals=15 433 | 434 | # Maximum number of parents for a class (see R0901). 435 | max-parents=7 436 | 437 | # Maximum number of public methods for a class (see R0904). 438 | max-public-methods=20 439 | 440 | # Maximum number of return / yield for function / method body. 441 | max-returns=6 442 | 443 | # Maximum number of statements in function / method body. 444 | max-statements=50 445 | 446 | # Minimum number of public methods for a class (see R0903). 447 | min-public-methods=2 448 | 449 | 450 | [CLASSES] 451 | 452 | # List of method names used to declare (i.e. assign) instance attributes. 453 | defining-attr-methods=__init__, 454 | __new__, 455 | setUp 456 | 457 | # List of member names, which should be excluded from the protected access 458 | # warning. 459 | exclude-protected=_asdict, 460 | _fields, 461 | _replace, 462 | _source, 463 | _make 464 | 465 | # List of valid names for the first argument in a class method. 466 | valid-classmethod-first-arg=cls 467 | 468 | # List of valid names for the first argument in a metaclass class method. 
469 | valid-metaclass-classmethod-first-arg=cls
470 | 
471 | 
472 | [EXCEPTIONS]
473 | 
474 | # Exceptions that will emit a warning when being caught. Defaults to
475 | # "Exception".
476 | overgeneral-exceptions=builtins.Exception
477 | 
--------------------------------------------------------------------------------
/punctfix/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import PunctFixer
2 | 
--------------------------------------------------------------------------------
/punctfix/inference.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from dataclasses import dataclass
3 | from typing import Tuple, Dict, List, Union, Optional
4 | import warnings
5 | import re
6 | 
7 | import torch
8 | from transformers import TokenClassificationPipeline
9 | 
10 | from punctfix.models import get_custom_model_and_tokenizer, get_english_model_and_tokenizer, \
11 |     get_danish_model_and_tokenizer, get_german_model_and_tokenizer
12 | 
13 | 
14 | WORD_NORMALIZATION_PATTERN = re.compile(r"[\W_]+")
15 | 
16 | class NoLanguageOrModelSelect(Exception):
17 |     """
18 |     Exception raised if you fail to specify either a language or custom model path.
19 |     """
20 | 
21 | class NonNormalizedTextWarning(RuntimeWarning):
22 |     """
23 |     Warning given when the input text does not follow the normalization used
24 |     during training.
25 |     """
26 | 
27 | @dataclass
28 | class WordPrediction:
29 |     """
30 |     Dataclass to hold word and labels for inference.
31 |     """
32 |     word: str
33 |     labels: List[str]
34 | 
35 |     @property
36 |     def label(self):
37 |         """
38 |         Label property. When called, at least one label should always be present.
39 | 
40 |         :return: A single model label as a str
41 |         """
42 |         return Counter(self.labels).most_common(1)[0][0]
43 | 
44 | 
45 | class PunctFixer:
46 |     """
47 |     PunctFixer used to punctuate a given text.
48 |     """
49 | 
50 |     def __init__(self, language: str = "da",
51 |                  custom_model_path: str = None,
52 |                  use_auth_token: Optional[Union[bool, str]] = None,
53 |                  word_overlap: int = 70,
54 |                  word_chunk_size: int = 100,
55 |                  device: Union[str, torch.device] = torch.device("cpu"),
56 |                  skip_normalization=False,
57 |                  warn_on_normalization=False,
58 |                  batch_size: int = 1
59 |                  ):
60 |         """
61 |         :param language: Valid options are "da", "de", "en", for Danish, German and English, respectively.
62 |         :param custom_model_path: If you have trained a model yourself. If passed, the language param will be ignored.
63 |         :param word_overlap: How many words should overlap in case text is too long. Defaults to 70.
64 |         :param word_chunk_size: How many words should a single pass consist of. Defaults to 100.
65 |         :param device: A torch.device on which to perform inference. The strings "cpu" or "cuda" can also be given.
66 |         :param skip_normalization: Don't check input text and don't normalize it.
67 |         :param warn_on_normalization: Warn the user if the input text was normalized.
68 |         :param batch_size: Number of text chunks to pass through the token classification pipeline at once.
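
        Example (illustrative; the pretrained model is downloaded from the
        Hugging Face Hub on first use):
            >>> fixer = PunctFixer(language="da", device="cpu")
            >>> punctuated = fixer.punctuate("hej med dig mit navn er rasmus")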
69 | """ 70 | 71 | self.word_overlap = word_overlap 72 | self.word_chunk_size = word_chunk_size 73 | self.skip_normalization = skip_normalization 74 | self.warn_on_normalization = warn_on_normalization 75 | self.batch_size = batch_size 76 | 77 | self.supported_languages = { 78 | "de": "German", 79 | "da": "Danish", 80 | "en": "English" 81 | } 82 | 83 | if custom_model_path: 84 | self.model, self.tokenizer = get_custom_model_and_tokenizer(custom_model_path, use_auth_token) 85 | elif language == "en": 86 | self.model, self.tokenizer = get_english_model_and_tokenizer() 87 | elif language == "da": 88 | self.model, self.tokenizer = get_danish_model_and_tokenizer() 89 | elif language == "de": 90 | self.model, self.tokenizer = get_german_model_and_tokenizer() 91 | else: 92 | raise NoLanguageOrModelSelect("You need to specify either language or custom_model_path " 93 | "when instantiating a PunctFixer.") 94 | 95 | self.tokenizer.decoder.cleanup = False 96 | self.model = self.model.eval() 97 | if isinstance(device, str): # Backwards compatability 98 | self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1 99 | else: 100 | self.device = device 101 | 102 | 103 | self.pipe = TokenClassificationPipeline(model=self.model, 104 | tokenizer=self.tokenizer, 105 | aggregation_strategy="first", 106 | device=self.device, 107 | ignore_labels=[]) 108 | 109 | def get_supported_languages(self) -> Dict[str, str]: 110 | """ 111 | Get a dict containing supported languages for PunctFixer. 112 | 113 | :return: dict containing support languages for PunctFixer 114 | """ 115 | return self.supported_languages 116 | 117 | @staticmethod 118 | def init_word_prediction_list(words: List[str]) -> List[WordPrediction]: 119 | """ 120 | Initialize a word prediction list i.e. a list containing WordPrediction object for each word. 121 | :param words: List of words 122 | 123 | :return: List of Word predictions 124 | """ 125 | return [WordPrediction(word=word, labels=[]) for word in words] 126 | 127 | def populate_word_prediction_with_labels(self, chunks: List[List[str]], word_prediction_list: List[WordPrediction]): 128 | """ 129 | Performs predictions on all chunks of text, and adds labels to the relevant word predictions. 130 | 131 | :param chunks: List of List of words 132 | :param word_prediction_list: A list containing word predictions i.e. word and labels. 133 | :return: Word predictions list with all label predictions for each word 134 | """ 135 | outputs = self.pipe([" ".join(chunk_text) for chunk_text in chunks], batch_size=self.batch_size) 136 | for i, output in enumerate(outputs): 137 | word_counter = 0 138 | for entity in output: 139 | label = entity["entity_group"] 140 | text = entity["word"] 141 | words_in_text = text.split(" ") 142 | 143 | for word in words_in_text: 144 | current_index = i * self.word_chunk_size + word_counter - (i * self.word_overlap) 145 | 146 | # Sanity check 147 | assert word_prediction_list[current_index].word == word, \ 148 | f"Something went wrong while matching word list ... 
" \ 149 | f"Tried matching the word: {word} with {word_prediction_list[current_index].word}" 150 | word_prediction_list[current_index].labels.append(label) 151 | word_counter += 1 152 | 153 | return word_prediction_list 154 | 155 | def combine_word_predictions_into_final_text(self, word_prediction_list: List[WordPrediction]): 156 | """ 157 | Combines all predictions for each word into a final string by checking label (majority vote or if equal 158 | predictions, it chooses however Counter from itertools chooses top_n. 159 | 160 | :param word_prediction_list: List of word predictions 161 | :return: A final string with punctuation 162 | """ 163 | final_text = [] 164 | auto_upper_next = False 165 | for word_pred in word_prediction_list: 166 | punctuated_text, auto_upper_next = self._combine_label_and_word(word_pred.label, 167 | word_pred.word, 168 | auto_upper_next) 169 | final_text.append(punctuated_text) 170 | 171 | return " ".join(final_text) 172 | 173 | def split_words_into_chunks(self, words: List[str]) -> List[List[str]]: 174 | """ 175 | Simple method to split a list of words into chunks of words with overlap. 176 | 177 | :param words: List of words to split into chunks 178 | :return: List of List of words consisting of the chunks 179 | """ 180 | return [words[i:i + self.word_chunk_size] 181 | for i in 182 | range(0, len(words), self.word_chunk_size - self.word_overlap)] 183 | 184 | def punctuate(self, text: str) -> str: 185 | """ 186 | Punctuates given text. 187 | 188 | :param text: A lowercase text with no punctuation. 189 | If it has punctuatation, it will be removed. 190 | :return: A punctuated text. 191 | """ 192 | words = self.split_input_text(text) 193 | 194 | # If we have a long sequence of text (measured by words), we split it into chunks 195 | chunks = [] 196 | if len(words) >= self.word_chunk_size: 197 | chunks = self.split_words_into_chunks(words) 198 | else: 199 | chunks.append(words) 200 | 201 | # We create a word prediction list and then combine the predictions to to final text 202 | word_prediction_list = self.init_word_prediction_list(words) 203 | word_prediction_list = self.populate_word_prediction_with_labels(chunks, word_prediction_list) 204 | return self.combine_word_predictions_into_final_text(word_prediction_list) 205 | 206 | def split_input_text(self, text: str) -> List[str]: 207 | """ 208 | Splits given text into words using whitespace tokenization, also performing normalization 209 | :param text: A lowercase text with no punctuation (otherwise normalized) 210 | :return: A list of the words in that text, splitted and normalized. 
211 | """ 212 | words = text.split(" ") 213 | if self.skip_normalization: 214 | return words 215 | 216 | normalized_words = [] 217 | to_warn = [] 218 | for word in words: 219 | if not word: 220 | to_warn.append("Additional whitespace was removed.") 221 | norm_word = WORD_NORMALIZATION_PATTERN.sub("", word) 222 | if not word: 223 | continue 224 | if len(norm_word) < len(word): 225 | to_warn.append(r"Non-word (r'\W') characters were removed.") 226 | # We might have removed the entire word 227 | if not norm_word: 228 | continue 229 | if not norm_word.islower(): 230 | norm_word = norm_word.lower() 231 | to_warn.append("Text was lowercased.") 232 | normalized_words.append(norm_word) 233 | 234 | # Warn once for each type of normalization 235 | if self.warn_on_normalization and to_warn: 236 | warnings.warn( 237 | "The input text was modified to follow model normalization: " + 238 | " ".join(sorted(set(to_warn))) + 239 | " To avoid seeing this, set suppress_normalization_warning=True. "\ 240 | "To entirely circumvent normalization, set skip_normalization=True. ", 241 | NonNormalizedTextWarning) 242 | return normalized_words 243 | 244 | @staticmethod 245 | def _combine_label_and_word(label: str, word: str, auto_uppercase: bool = False) -> Tuple[str, bool]: 246 | """ 247 | Combines label and word into a single string by looking at the label to see what should be changed. 248 | 249 | :param label: Punctuation label from the model 250 | :param word: Word to handle 251 | :param auto_uppercase: Whether automatically uppercase independent of label 252 | :return: Tuple with a str consisting of the word with relevant punctuation and 253 | whether to auto capitalize next word 254 | """ 255 | next_auto_uppercase = False 256 | if label[-1] == "U": 257 | word = word.capitalize() 258 | 259 | if label[0] != "O": 260 | word += label[0] 261 | 262 | if auto_uppercase: 263 | word = word.capitalize() 264 | 265 | if label[0] in {".", "!", "?"}: 266 | next_auto_uppercase = True 267 | 268 | return word, next_auto_uppercase 269 | -------------------------------------------------------------------------------- /punctfix/models.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union 2 | 3 | from transformers import AutoTokenizer, AutoModelForTokenClassification, BertTokenizerFast 4 | 5 | 6 | def get_english_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, AutoTokenizer]: 7 | """ 8 | Gets English transformer model and tokenizer 9 | :return: Tuple with (model, tokenizer) 10 | """ 11 | model_id = "Alvenir/bert-punct-restoration-en" 12 | model = AutoModelForTokenClassification.from_pretrained(model_id) 13 | tokenizer = BertTokenizerFast.from_pretrained(model_id) 14 | return model, tokenizer 15 | 16 | 17 | def get_danish_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, BertTokenizerFast]: 18 | """ 19 | Gets Danish transformer model and tokenizer 20 | :return: Tuple with (model, tokenizer) 21 | """ 22 | model_id = "Alvenir/bert-punct-restoration-da" 23 | tokenizer = BertTokenizerFast.from_pretrained(model_id) 24 | model = AutoModelForTokenClassification.from_pretrained(model_id) 25 | return model, tokenizer 26 | 27 | 28 | def get_german_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, AutoTokenizer]: 29 | """ 30 | Gets German transformer model and tokenizer 31 | :return: Tuple with (model, tokenizer) 32 | """ 33 | model_id = "Alvenir/bert-punct-restoration-de" 34 | tokenizer = 
--------------------------------------------------------------------------------
/punctfix/models.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, Union
2 | 
3 | from transformers import AutoTokenizer, AutoModelForTokenClassification, BertTokenizerFast
4 | 
5 | 
6 | def get_english_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, AutoTokenizer]:
7 |     """
8 |     Gets English transformer model and tokenizer
9 |     :return: Tuple with (model, tokenizer)
10 |     """
11 |     model_id = "Alvenir/bert-punct-restoration-en"
12 |     model = AutoModelForTokenClassification.from_pretrained(model_id)
13 |     tokenizer = BertTokenizerFast.from_pretrained(model_id)
14 |     return model, tokenizer
15 | 
16 | 
17 | def get_danish_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, BertTokenizerFast]:
18 |     """
19 |     Gets Danish transformer model and tokenizer
20 |     :return: Tuple with (model, tokenizer)
21 |     """
22 |     model_id = "Alvenir/bert-punct-restoration-da"
23 |     tokenizer = BertTokenizerFast.from_pretrained(model_id)
24 |     model = AutoModelForTokenClassification.from_pretrained(model_id)
25 |     return model, tokenizer
26 | 
27 | 
28 | def get_german_model_and_tokenizer() -> Tuple[AutoModelForTokenClassification, AutoTokenizer]:
29 |     """
30 |     Gets German transformer model and tokenizer
31 |     :return: Tuple with (model, tokenizer)
32 |     """
33 |     model_id = "Alvenir/bert-punct-restoration-de"
34 |     tokenizer = BertTokenizerFast.from_pretrained(model_id)
35 |     model = AutoModelForTokenClassification.from_pretrained(model_id)
36 |     return model, tokenizer
37 | 
38 | 
39 | def get_custom_model_and_tokenizer(
40 |     model_path: str,
41 |     use_auth_token: Optional[Union[bool, str]] = None
42 | ) -> Tuple[AutoModelForTokenClassification, AutoTokenizer]:
43 |     """
44 |     Gets local transformer model and tokenizer
45 |     :return: Tuple with (model, tokenizer)
46 |     """
47 |     model = AutoModelForTokenClassification.from_pretrained(model_path, use_auth_token=use_auth_token)
48 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, use_auth_token=use_auth_token)
49 |     return model, tokenizer
50 | 
--------------------------------------------------------------------------------
/punctfix/streaming.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | 
3 | from punctfix.inference import PunctFixer, WordPrediction
4 | 
5 | 
6 | class PunctFixStreamer:
7 |     """
8 |     A stateful streamer that receives text in segments, performing punct-fixing on-line and
9 |     returning partial results during streaming. These partial results are guaranteed to be
10 |     final.
11 |     """
12 | 
13 |     chunked_words: List[WordPrediction]
14 |     buffer: List[WordPrediction]
15 | 
16 |     def __init__(self, punct_fixer: PunctFixer):
17 |         """
18 |         Takes in an instantiated punct fixer.
19 |         """
20 |         self.punct_fixer = punct_fixer
21 |         self.clear()
22 | 
23 |     def __call__(self, new_text_segment: str) -> Optional[str]:
24 |         """
25 |         Stream in new text, returning None if this new text did not change anything
26 |         and the partial, finalized text if there have been updates to it.
27 |         """
28 |         self.buffer.extend(
29 |             self.punct_fixer.init_word_prediction_list(
30 |                 self.punct_fixer.split_input_text(new_text_segment)
31 |             )
32 |         )
33 |         if self.process_buffer():
34 |             return self.get_result()
35 |         return None
36 | 
37 |     def finalize(self):
38 |         """
39 |         Mark the end of the stream and return the final punctuated string.
40 |         """
41 |         self.process_buffer(is_finalized=True)
42 |         punctuated = self.get_result(is_finalized=True)
43 |         self.clear()
44 |         return punctuated
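
    # Example usage (illustrative; segment boundaries are arbitrary, and how much
    # text each call finalizes depends on the fixer's chunk size and overlap):
    #     streamer = PunctFixStreamer(PunctFixer(language="da"))
    #     streamer("mit navn det er rasmus og jeg kommer")  # may return partial text or None
    #     streamer("fra firmaet alvenir")
    #     full_text = streamer.finalize()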
45 | 
46 |     def get_result(self, is_finalized=False) -> str:
47 |         """
48 |         Returns the punctuated string of all inputs streamed in so far.
49 |         If called when not finalized, will only return text that is certain/no longer subject to change.
50 |         """
51 |         if is_finalized:
52 |             finalized_words = self.chunked_words
53 |         # These lines perform a tricky calculation in a dumb way:
54 |         # When is each word finalized? When it has gotten all the labels that it will get.
55 |         # This number of labels is not constant across the sequence and depends on overlap
56 |         # size and on chunk size. To avoid trying to be clever, I just calculate the chunks
57 |         # and overlaps and sum up how many times each index will be in a chunk.
58 |         else:
59 |             # The + chunk size makes the calculation take into account that there will be more
60 |             # chunks in the future and that we should not finalize prematurely
61 |             final_num_preds = [0] * (
62 |                 len(self.chunked_words) + self.punct_fixer.word_chunk_size
63 |             )
64 |             for chunk in self.punct_fixer.split_words_into_chunks(
65 |                 range(len(self.chunked_words))
66 |             ):
67 |                 for idx in chunk:
68 |                     final_num_preds[idx] += 1
69 |             finalized_words = [
70 |                 word
71 |                 for i, word in enumerate(self.chunked_words)
72 |                 if len(word.labels) == final_num_preds[i]
73 |             ]
74 |         return self.punct_fixer.combine_word_predictions_into_final_text(
75 |             finalized_words
76 |         )
77 | 
78 |     def process_buffer(self, is_finalized=False) -> bool:
79 |         """
80 |         Performs the actual punctfixing of the buffer content, updating internal state such that a maximal number
81 |         of words get predicted labels. Returns True if new chunks were created and processed and False if not.
82 |         """
83 |         new_chunks = []
84 |         # Save how many words were chunked before this call
85 |         this_processing_started_at = (
86 |             len(self.chunked_words) - self.punct_fixer.word_overlap
87 |             if self.chunked_words
88 |             else 0
89 |         )
90 |         # Whole chunks are appended unless the stream is finalized, in which case the buffer
91 |         # is completely emptied
92 |         while len(self.buffer) >= self.punct_fixer.word_chunk_size or (
93 |             is_finalized and self.buffer
94 |         ):
95 |             new_chunks.append(
96 |                 [word.word for word in self.buffer[: self.punct_fixer.word_chunk_size]]
97 |             )
98 |             # Not all words are chunked for the first time; we must (except for the first time)
99 |             # skip the first `word_overlap` words to avoid duplicates.
100 |             already_chunked_idx = (
101 |                 self.punct_fixer.word_overlap if self.chunked_words else 0
102 |             )
103 |             self.chunked_words.extend(
104 |                 self.buffer[already_chunked_idx : self.punct_fixer.word_chunk_size]
105 |             )
106 |             # We don't remove the entire buffer length from the buffer as we want
107 |             # to emulate the overlap feature of the punctfixer; we leave some in there for the next chunk.
108 |             self.buffer = self.buffer[
109 |                 self.punct_fixer.word_chunk_size - self.punct_fixer.word_overlap :
110 |             ]
111 |         if new_chunks:
112 |             # Run the forward pass on all new chunks, matching with the words that are included in them
113 |             self.punct_fixer.populate_word_prediction_with_labels(
114 |                 new_chunks, self.chunked_words[this_processing_started_at:]
115 |             )
116 |             return True
117 |         return False
118 | 
119 |     def clear(self):
120 |         """
121 |         Reset internal state.
122 | """ 123 | self.buffer = [] 124 | self.chunked_words = [] 125 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers >= 0.11.6 2 | transformers >= 4.13 3 | torch 4 | -------------------------------------------------------------------------------- /scripts/test_timing.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | from punctfix import PunctFixer 4 | 5 | MODEL_INPUT = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ 6 | "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \ 7 | "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \ 8 | "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \ 9 | "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \ 10 | " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ 11 | "de store folkemasser der er mødt op her på" * 10 12 | 13 | def time_fp(device_str: str, batch_size: int): 14 | print(">>> Profiling device %s on batch size %i" % (device_str, batch_size)) 15 | start = time() 16 | model = PunctFixer(language="da", device=device_str, batch_size=batch_size) 17 | print("Initialization time %f" % (time() - start)) 18 | 19 | # Warmup potential CUDA device 20 | model.punctuate(MODEL_INPUT) 21 | 22 | times = [] 23 | for _ in range(5): 24 | start = time() 25 | model.punctuate(MODEL_INPUT) 26 | times.append(time() - start) 27 | print("Average time: %f\nStd. time: %f" % (torch.tensor(times).mean().item(), torch.tensor(times).std().item())) 28 | 29 | 30 | if __name__ == "__main__": 31 | devices = ["cpu"] 32 | batch_sizes = [1, 16, 32, 64] 33 | if torch.cuda.is_available(): 34 | devices.append("cuda") 35 | for device in devices: 36 | for batch_size in batch_sizes: 37 | time_fp(device, batch_size) 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("./README.md", "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | with open('./requirements.txt', encoding="utf-8") as f: 8 | requirements = f.read().splitlines() 9 | 10 | setuptools.setup( 11 | name="punctfix", 12 | version="0.11.1", 13 | author="Martin Carsten Nielsen", 14 | author_email="martin@alvenir.ai", 15 | description="Punctuation restoration library", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | packages=setuptools.find_packages(), 19 | include_package_data=True, 20 | install_requires=requirements, 21 | license_file="LICENCE.txt", 22 | url="https://github.com/danspeech/punctfix", 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | 'Development Status :: 5 - Production/Stable', 26 | "Operating System :: OS Independent", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvenirai/punctfix/00bd061d05b4175cd08985f6f021a5a8bb85b108/tests/__init__.py -------------------------------------------------------------------------------- 
/tests/test_punctuation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock, ANY 3 | 4 | from punctfix import PunctFixer 5 | from punctfix.inference import NonNormalizedTextWarning 6 | from punctfix.streaming import PunctFixStreamer 7 | 8 | class CleanupDisableTest(unittest.TestCase): 9 | 10 | def setUp(self) -> None: 11 | super().setUp() 12 | self.model = PunctFixer(language="en") 13 | 14 | def tearDown(self) -> None: 15 | super().tearDown() 16 | self.model = None 17 | 18 | def test_donot_does_not_become_dont(self): 19 | model_input = "hello i am brian i do not like koalas" 20 | expected_output = "Hello I am Brian. I do not like Koalas." 21 | 22 | actual_output = self.model.punctuate(model_input) 23 | 24 | self.assertEqual(actual_output, expected_output) 25 | 26 | 27 | class DanishPunctuationRestorationTest(unittest.TestCase): 28 | 29 | def setUp(self) -> None: 30 | super().setUp() 31 | self.model = PunctFixer(language="da") 32 | 33 | def tearDown(self) -> None: 34 | super().tearDown() 35 | self.model = None 36 | 37 | def test_sample01(self): 38 | model_input = "mit navn det er rasmus og jeg kommer fra firmaet alvenir " \ 39 | "det er mig som har trænet denne lækre model" 40 | expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \ 41 | "Det er mig som har trænet denne lækre model." 42 | 43 | actual_output = self.model.punctuate(model_input) 44 | 45 | self.assertEqual(actual_output, expected_output) 46 | 47 | def test_sample02(self): 48 | model_input = "en dag bliver vi sku glade for at vi nu kan sætte punktummer " \ 49 | "og kommaer i en sætning det fungerer da meget godt ikke" 50 | expected_output = "En dag bliver vi sku glade for, at vi nu kan sætte punktummer " \ 51 | "og kommaer i en sætning. Det fungerer da meget godt, ikke?" 52 | 53 | actual_output = self.model.punctuate(model_input) 54 | 55 | self.assertEqual(actual_output, expected_output) 56 | 57 | def test_sample03(self): 58 | model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ 59 | "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \ 60 | "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \ 61 | "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \ 62 | "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \ 63 | " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ 64 | "de store folkemasser der er mødt op her på" 65 | 66 | actual_output = self.model.punctuate(model_input) 67 | 68 | self.assertIsNotNone(actual_output) 69 | 70 | 71 | class EnglishPunctuationRestorationTest(unittest.TestCase): 72 | 73 | def setUp(self) -> None: 74 | super().setUp() 75 | self.model = PunctFixer(language="en") 76 | 77 | def tearDown(self) -> None: 78 | super().tearDown() 79 | self.model = None 80 | 81 | def test_sample01(self): 82 | model_input = "hello i come from denmark and i am very good at english" 83 | expected_output = "Hello! I come from Denmark and I am very good at English." 
84 | 85 | actual_output = self.model.punctuate(model_input) 86 | 87 | self.assertEqual(actual_output, expected_output) 88 | 89 | def test_sample02(self): 90 | model_input = "do you really want to know this is so weird to write text just to see if it works does it work" 91 | expected_output = "Do you really want to know? This is so weird to write text just to see " \ 92 | "if it works? Does it work?" 93 | 94 | actual_output = self.model.punctuate(model_input) 95 | 96 | self.assertEqual(actual_output, expected_output) 97 | 98 | 99 | class GermanPunctuationRestorationTest(unittest.TestCase): 100 | 101 | def setUp(self) -> None: 102 | super().setUp() 103 | self.model = PunctFixer(language="de") 104 | 105 | def tearDown(self) -> None: 106 | super().tearDown() 107 | self.model = None 108 | 109 | def test_sample01(self): 110 | model_input = "oscar geht einkaufen in einen großen supermarkt seine einkaufsliste ist lang er " \ 111 | "kauft für das ganze wochenende ein außerdem kommen gäste für die er kochen wird" 112 | expected_output = "Oscar geht einkaufen in einen großen Supermarkt. Seine Einkaufsliste ist lang. " \ 113 | "Er kauft für das ganze Wochenende ein. Außerdem kommen Gäste, für die er kochen wird." 114 | 115 | actual_output = self.model.punctuate(model_input) 116 | 117 | self.assertEqual(actual_output, expected_output) 118 | 119 | 120 | class HandlingOfLargeTextsTest(unittest.TestCase): 121 | 122 | def setUp(self) -> None: 123 | super().setUp() 124 | self.model = PunctFixer(language="da") 125 | self.long_text = "halløj derude og vel kom til en podcast med jacob og neviton du kan lige starte med at fortælle lidt om dig selv ja ja men det er så mig der jakob den ene halvdel af det her potkarsthold det er jo noget nyt for os det her det er mogens amatører det måske også lige onderestmidtet det var fald den første potkarstvidhverår så vi var i korte introducerer introns vase bare bare fortæller om det er jo næsten lige ja vi faldt næsten jævnaldrende vi bor i samme by ja vi bor i viborg vi bor og stort set helt års liv er det vær du det har jeg larlasesjeg ar altid sådan tilhørt vi går jeg kommer jo fra en lille landsby men jeg tror de kom til at høre mere om osten sener han jeg skal nok komme sådan lidt men gode personlige historier om nu vi jeg tog lidt mærkelig fyr som vi nu og vi har en hel del beatet er det må man sige en grund til vi egentlig startede mig på et kaste fordi at vi altid vi gør det dem på at snakke om alle mulige ting i flere timer faktisk så tog jeg mig selv lige tænk på tænk tanken om at staten podcast det er jo lidt maiestremehvertfald efterhånden her det er så en sjov det er sjovt når at se til bag på så tænkte jeg hvorfor ikke og det er faktisk hende det første ting vi kan huske vi latham jeg var så ikke en kot karst men men jeg skulle hente nyredtanengang efter en bitur og jeg havde ikke sådan jeg havde kendt ham en nu er det måske er og de ting vi snakker om det var ikke bare sådan en af de to ting det var virkelig det var lige for tegnefilmen til livet altså være mening med livet så er der er ikke noget bestemt i komme til at høre de klokken tre om natten efter sådan bytur koge rundt i en bil det er bare gode tider og penge til på den nu ja ser for i det korona nu har jeg så sammen med det lidt anvende ap vi sidst nogen der hører med her i to tusinde og niogtyve så kun leve en palmi der er så lidt til historie børn en ny anne de vi bare snart med denne uge men vi kan jo også komme op med et navn til podcasten ja det er og sådan for titemitersi jamen vi har venner 
kan læge for kort og godt for at sige i kor et godt ja ser han det års ugerne jamen jeg ved silende med at sige nej det var i hvert fald bare lige sådan en kort introduktion jeg har tænkt på hvad vi skal snakke om næste gang og det kan de vi lave i en ilespoyler ja til folket i gik jeg ned at for altså jakob et meget formelt når jeg trækker med amrepoasnakerom mer så så meget åben odsngetsonogeernår han snakker om at så gar krimeligterdetså men nå hvor sprængt om ja hvordan det ændrer så altså hvordan vi var kun så hvordan børn er nu han når der er jo sket ikke mere og så hvis der sket nogle sjove ting og brøndum som jeg lige kan huske sjove fortællinger jeg vil jo i stå jeg er lige sket som vil det tænker en bar og så kommer vi jo hundrede procent af indre så kom nok rundt omkring hvad der lige sker så det hvert fald det ja jeg har hvert fad ikke med at sige jamen skal vi afslutte den for i dag så skal du så sige os jeg ved ikke hvordan migarmename det bliver i hvert fald bedre bedre betidehkavisde sige vor velkendende ord gør når vi alle det er ikke så kendt fordi det er første gang men det bliver det for en på en tidspunkt færdig" 126 | 127 | def tearDown(self) -> None: 128 | super().tearDown() 129 | self.model = None 130 | self.long_text = None 131 | 132 | def test_multiple_chunk_size_and_padding_configs(self): 133 | configs_to_test = [ 134 | (50, 5), 135 | (50, 10), 136 | (50, 20), 137 | (50, 40), 138 | (100, 5), 139 | (100, 20), 140 | (100, 40), 141 | (100, 80), 142 | (150, 10), 143 | (150, 20), 144 | (150, 50), 145 | (150, 80), 146 | (150, 120), 147 | (200, 50), 148 | (200, 100), 149 | (200, 150), 150 | (200, 180) 151 | ] 152 | 153 | for chunk_size, overlap in configs_to_test: 154 | self.model.word_overlap = overlap 155 | self.model.word_chunk_size = chunk_size 156 | 157 | actual_output = self.model.punctuate(self.long_text) 158 | self.assertIsNotNone(actual_output) 159 | 160 | 161 | class GenerelFunctionalityTest(unittest.TestCase): 162 | 163 | def setUp(self) -> None: 164 | super().setUp() 165 | # Setup 166 | self.torch_cuda_patch = patch( 167 | 'punctfix.inference.torch.cuda.is_available' 168 | ) 169 | self.torch_cuda_mock: MagicMock = self.torch_cuda_patch.start() 170 | self.torch_cuda_mock.return_value = False 171 | 172 | self.token_classification_pipeline_patch = patch( 173 | 'punctfix.inference.TokenClassificationPipeline' 174 | ) 175 | self.token_classification_pipeline_mock: MagicMock = self.token_classification_pipeline_patch.start() 176 | 177 | def test_if_gpu_not_available_default_cpu(self): 178 | # When 179 | self.model = PunctFixer(language="da", device="cuda") 180 | 181 | # Expect 182 | self.token_classification_pipeline_mock.assert_called_once_with(model=ANY, 183 | tokenizer=ANY, 184 | aggregation_strategy="first", 185 | device=-1, 186 | ignore_labels=ANY) 187 | 188 | 189 | def tearDown(self) -> None: 190 | super().tearDown() 191 | self.torch_cuda_patch.stop() 192 | self.token_classification_pipeline_patch.stop() 193 | 194 | class NormalizationTest(unittest.TestCase): 195 | 196 | def setUp(self) -> None: 197 | super().setUp() 198 | self.model = PunctFixer(language="da") 199 | 200 | def tearDown(self) -> None: 201 | super().tearDown() 202 | self.model = None 203 | 204 | def test_do_normalize(self): 205 | self.model.warn_on_normalization = False 206 | expected_output = ["hejsa", "mand"] 207 | for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand", 208 | "Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand", 209 | "hejsa % mand ! 
% "): 210 | actual_output = self.model.split_input_text(model_input) 211 | self.assertEqual(actual_output, expected_output) 212 | 213 | def test_warnings(self): 214 | self.model.warn_on_normalization = True 215 | with self.assertWarns(NonNormalizedTextWarning): 216 | model_input = "hejsa, mand" 217 | self.model.split_input_text(model_input) 218 | 219 | with self.assertWarns(NonNormalizedTextWarning): 220 | model_input = "hejsa mand" 221 | self.model.split_input_text(model_input) 222 | 223 | with self.assertWarns(NonNormalizedTextWarning): 224 | model_input = "hejsa Mand" 225 | self.model.split_input_text(model_input) 226 | 227 | def test_do_not_normalize(self): 228 | model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ 229 | "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \ 230 | "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \ 231 | "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \ 232 | "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \ 233 | " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ 234 | "de store folkemasser der er mødt op her på" 235 | expected_output = model_input.split(" ") 236 | actual_output = self.model.split_input_text(model_input) 237 | self.assertEqual(actual_output, expected_output) 238 | 239 | class InputParameterTest(unittest.TestCase): 240 | def test_setting_batch_size(self): 241 | model_input = "mit navn det er rasmus og jeg kommer fra firmaet alvenir " \ 242 | "det er mig som har trænet denne lækre model" 243 | expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \ 244 | "Det er mig som har trænet denne lækre model." 245 | for batch_size in 1, 27, 99: 246 | model = PunctFixer(language="da", batch_size=batch_size) 247 | actual_output = model.punctuate(model_input) 248 | self.assertEqual(actual_output, expected_output) 249 | 250 | class PunctFixStreamerTest(unittest.TestCase): 251 | 252 | def setUp(self) -> None: 253 | super().setUp() 254 | self.streamer = PunctFixStreamer(PunctFixer(language="da")) 255 | 256 | def tearDown(self) -> None: 257 | super().tearDown() 258 | del self.streamer 259 | 260 | def test_sample01(self): 261 | model_inputs = "mit navn det er rasmus", "og jeg kommer", "fra firmaet alvenir",\ 262 | "det er mig", "som har trænet", "denne", "lækre model" 263 | expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \ 264 | "Det er mig som har trænet denne lækre model." 265 | 266 | for input_ in model_inputs: 267 | self.streamer(input_) 268 | actual_output = self.streamer.finalize() 269 | self.assertEqual(actual_output, expected_output) 270 | 271 | def test_sample02(self): 272 | model_inputs = "en dag bliver vi sku glade", "for", "at vi nu kan", "sætte punktummer ",\ 273 | "og kommaer", "i", "en", "sætning det fungerer da meget", "godt ikke" 274 | expected_output = "En dag bliver vi sku glade for, at vi nu kan sætte punktummer " \ 275 | "og kommaer i en sætning. Det fungerer da meget godt, ikke?" 
276 |         for input_ in model_inputs:
277 |             self.streamer(input_)
278 |         actual_output = self.streamer.finalize()
279 |         self.assertEqual(actual_output, expected_output)
280 | 
281 |     def test_sample03(self):
282 |         # We want it super long to force multiple chunks
283 |         model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
284 |                       "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
285 |                       "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
286 |                       "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
287 |                       "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
288 |                       " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
289 |                       "de store folkemasser der er mødt op her på byggepladsen her er ekstra ord " * 2
290 |         expected_output = self.streamer.punct_fixer.punctuate(model_input)
291 |         for w in model_input.split():
292 |             partial_output = self.streamer(w)
293 |             if partial_output is not None:
294 |                 self.assertIn(partial_output, expected_output)
295 |         actual_output = self.streamer.finalize()
296 |         self.assertEqual(actual_output, expected_output)
297 | 
298 |     def test_repeated_same_input(self):
299 |         self.streamer("test")
300 |         self.streamer("test")
301 |         output = self.streamer.finalize()
302 |         expected_output = "Test test."
303 |         self.assertEqual(output, expected_output)
304 | 
305 |     def test_empty_string_input(self):
306 |         self.streamer("")
307 |         output = self.streamer.finalize()
308 |         self.assertEqual(output, "")
309 | 
310 |     ### The tests below are isolated unit tests for individual methods
311 |     def test_get_result_method(self):
312 |         self.streamer("test test")
313 |         # Call get_result at an intermediate state
314 |         output = self.streamer.get_result()
315 |         self.assertEqual(output, "")
316 |         # Call get_result at a final state but without processing the buffer
317 |         output = self.streamer.get_result(is_finalized=True)
318 |         self.assertEqual(output, "")
319 |         # Call get_result at a final state where there is enough data for the buffer to be processed
320 |         self.streamer(" ".join(["test"] * 98))  # brings the total to 100 words, i.e. the chunk size
321 |         output = self.streamer.get_result(is_finalized=True)
322 |         self.assertEqual(output, ("Test " * 100)[:-1])
323 | 
324 |     def test_finalize_method(self):
325 |         self.streamer("finalizing test")
326 |         output = self.streamer.finalize()
327 |         expected_output = "Finalizing Test."
328 |         self.assertEqual(output, expected_output)
329 | 
330 |     def test_call_method(self):
331 |         self.streamer("test input")
332 |         self.assertEqual(len(self.streamer.buffer), 2)
333 |         self.assertEqual(len(self.streamer.chunked_words), 0)
334 |         self.streamer("test " * 100)
335 |         self.assertEqual(len(self.streamer.buffer), 72)  # overlap size (70) + the 2 earlier words
336 |         self.assertEqual(len(self.streamer.chunked_words), 100)  # chunk size
337 | 
338 |     def test_clear_method(self):
339 |         self.streamer("clearing test")
340 |         self.streamer.clear()
341 |         self.assertEqual(self.streamer.buffer, [])
342 |         self.assertEqual(self.streamer.chunked_words, [])
343 | 
344 | if __name__ == '__main__':
345 |     unittest.main()
346 | 
--------------------------------------------------------------------------------
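To close, a minimal sketch of the streaming API exercised by PunctFixStreamerTest above. The names (PunctFixStreamer, finalize, clear, and calling the streamer with a text segment) come from punctfix/streaming.py and the tests; the chunk-size and overlap figures in the comments are inferred from test_call_method, not documented constants.

from punctfix import PunctFixer
from punctfix.streaming import PunctFixStreamer

streamer = PunctFixStreamer(PunctFixer(language="da"))

# Feed text as it arrives, e.g. from an ASR system. A call returns already-finalized
# punctuated text once a full chunk has been buffered (100 words by default, judging
# from test_call_method), otherwise None.
for segment in ("mit navn det er rasmus", "og jeg kommer", "fra firmaet alvenir",
                "det er mig", "som har trænet", "denne", "lækre model"):
    partial = streamer(segment)
    if partial is not None:
        print(partial)

# finalize() punctuates whatever remains in the buffer and returns the full result;
# clear() empties buffer and chunked_words so the streamer can be reused.
print(streamer.finalize())
streamer.clear()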