├── version.txt ├── cttpunctuator ├── src │ ├── onnx │ │ ├── punc.bin │ │ └── configuration.json │ ├── utils │ │ ├── text_post_process.py │ │ └── OrtInferSession.py │ └── punctuator.py └── __init__.py ├── MANIFEST.in ├── .flake8 ├── test └── test.py ├── LICENSE ├── .github └── workflows │ └── python-package.yml ├── cttPunctuator.py ├── setup.py ├── .gitignore └── README.md /version.txt: -------------------------------------------------------------------------------- 1 | 0.0.2 -------------------------------------------------------------------------------- /cttpunctuator/src/onnx/punc.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lovemefan/CT-Transformer-punctuation/HEAD/cttpunctuator/src/onnx/punc.bin -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include cttpunctuator/src/onnx/configuration.json 2 | include cttpunctuator/src/onnx/punc.onnx 3 | include cttpunctuator/src/onnx/punc.yaml -------------------------------------------------------------------------------- /cttpunctuator/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # @FileName :__init__.py.py 3 | # @Time :2023/4/13 14:58 4 | # @Author :lovemefan 5 | # @Email :lovemefan@outlook.com 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | show-source=true 3 | statistics=true 4 | max-line-length = 108 5 | per-file-ignores = 6 | # line too long 7 | runtime/src/utils/kaldifeat/feature.py : E501 8 | runtime/src/utils/kaldifeat/ivector.py : E501 9 | 10 | exclude = 11 | .git, 12 | **/data/** 13 | **/onnx/** 14 | 15 | ignore = 16 | # E203 white space before ":" 17 | E203, 18 | # W503 line break before 
binary operator 19 | W503, 20 | # E226 missing whitespace around arithmetic operator 21 | E226, 22 | -------------------------------------------------------------------------------- /cttpunctuator/src/onnx/configuration.json: -------------------------------------------------------------------------------- 1 | { 2 | "framework": "onnx", 3 | "task" : "punctuation", 4 | "model" : { 5 | "type" : "generic-punc", 6 | "punc_model_name" : "punc.pb", 7 | "punc_model_config" : { 8 | "type": "pytorch", 9 | "code_base": "funasr", 10 | "mode": "punc", 11 | "lang": "zh-cn", 12 | "batch_size": 1, 13 | "punc_config": "punc.yaml", 14 | "model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" 15 | } 16 | }, 17 | "pipeline": { 18 | "type":"punc-inference" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # @FileName :test.py.py 3 | # @Time :2023/4/19 13:39 4 | # @Author :lovemefan 5 | # @Email :lovemefan@outlook.com 6 | 7 | import logging 8 | 9 | from cttPunctuator import CttPunctuator 10 | 11 | logging.basicConfig( 12 | level=logging.INFO, 13 | format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s", 14 | ) 15 | # offline mode 16 | punc = CttPunctuator() 17 | text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书" 18 | logging.info(punc.punctuate(text)[0]) 19 | 20 | # online mode 21 | punc = CttPunctuator(online=True) 22 | text_in = ( 23 | "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|" 24 | "在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|" 25 | "向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|" 26 | "愿意进一步完善双方联合工作机制|凡是|中方能做的我们|" 27 | "都会去做而且会做得更好我请印度朋友们放心中国在上游的|" 28 | "任何开发利用都会经过科学|规划和论证兼顾上下游的利益" 29 | ) 30 | 31 | vads = text_in.split("|") 32 | rec_result_all = "" 33 | for vad in vads: 34 | result = punc.punctuate(vad) 35 | rec_result_all += result[0] 36 | logging.info(f"Part: 
{rec_result_all}") 37 | 38 | logging.info(f"Final: {rec_result_all}") 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014-2017 Alexey Popravka 4 | Copyright (c) 2021 Sean Stewart 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
-------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{matrix.os}} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: [ubuntu-20.04, windows-2019, macos-11] 20 | python-version: ["3.7.2", "3.8.0", "3.9.0", "3.10.0", "3.11.4"] 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python3 -m pip install --upgrade pip 31 | pip install -U setuptools wheel 32 | - name: Test ctt punctuator 33 | run: | 34 | curl -L "https://huggingface.co/lovemefan/ctt_punctuator/resolve/main/cttpunctuator/src/onnx/punc.onnx" -o cttpunctuator/src/onnx/punc.onnx 35 | pip3 install -e . 
class CttPunctuator:
    """Facade over the CT-Transformer punctuation models.

    Model instances are expensive to build, so they are cached as
    class-level singletons guarded by the module-level ``lock``
    (double-checked locking); every CttPunctuator instance shares them.
    """

    # Shared singleton models, lazily loaded on first use.
    _offline_model = None
    _online_model = None

    def __init__(self, online: bool = False):
        """
        punctuator with singleton pattern
        :param online: True -> streaming (VAD real-time) model,
            False -> offline (whole-utterance) model.
        """
        self.online = online

        if online:
            if CttPunctuator._online_model is None:
                with lock:
                    if CttPunctuator._online_model is None:
                        logging.info("Initializing punctuator model with online mode.")
                        CttPunctuator._online_model = CT_Transformer_VadRealtime()
                        logging.info("Online model initialized.")
            # Bug fix: every online instance needs its own streaming cache.
            # Previously param_dict was only created together with the
            # singleton, so any later CttPunctuator(online=True) raised
            # AttributeError inside punctuate().
            self.param_dict = {"cache": []}
            self.model = CttPunctuator._online_model

        else:
            if CttPunctuator._offline_model is None:
                with lock:
                    if CttPunctuator._offline_model is None:
                        logging.info("Initializing punctuator model with offline mode.")
                        CttPunctuator._offline_model = CT_Transformer()
                        logging.info("Offline model initialized.")
            self.model = CttPunctuator._offline_model

        logging.info("Model initialized.")

    def punctuate(self, text: str, param_dict=None):
        """Punctuate *text*.

        :param text: raw, unpunctuated text.
        :param param_dict: optional streaming state ({"cache": [...]}) for
            online mode; defaults to this instance's own cache.  Ignored in
            offline mode.
        :return: whatever the underlying model returns — a
            (punctuated_text, punctuation_ids) pair.
        """
        if self.online:
            # `is None` rather than `or`: an explicitly passed (possibly
            # empty) cache dict must be honored, not silently replaced.
            if param_dict is None:
                param_dict = self.param_dict
            # Bug fix: forward the resolved dict; the old code always passed
            # self.param_dict, ignoring the caller's argument.
            return self.model(text, param_dict)
        else:
            return self.model(text)
class TokenIDConverterError(Exception):
    """Raised when token/id conversion receives malformed input."""

    pass


class TokenIDConverter:
    """Bidirectional mapping between tokens and integer ids.

    The last entry of *token_list* is treated as the unknown-token symbol;
    tokens not present in the vocabulary map to its id.
    """

    def __init__(
        self,
        token_list: Union[List, str],
    ):
        check_argument_types()

        self.token_list = token_list
        # Convention: the unknown symbol is the final vocabulary entry.
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the vocabulary size."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map ids back to tokens; ndarray input must be 1-D.

        :raises TokenIDConverterError: if a multi-dimensional ndarray is given.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}"
            )
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map tokens to ids, falling back to the unknown id."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]


def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Split *words* into consecutive chunks of at most *word_limit* items.

    :param words: token list to split.
    :param word_limit: maximum chunk size; must be > 1.
    :return: list of chunks (the original list order is preserved).
    """
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit : (i + 1) * word_limit])
    # Remainder shorter than word_limit becomes the final chunk.
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit :])
    return sentences


def code_mix_split_words(text: str):
    """Tokenize mixed Chinese/ASCII text.

    Whitespace-separated ASCII runs become single tokens, while every
    multi-byte (e.g. Chinese) character becomes its own token.
    """
    words = []
    segs = text.split()
    for seg in segs:
        # There is no space in seg.
        current_word = ""
        for c in seg:
            if len(c.encode()) == 1:
                # This is an ASCII char.
                current_word += c
            else:
                # This is a Chinese char: flush any pending ASCII run first.
                if len(current_word) > 0:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if len(current_word) > 0:
            words.append(current_word)
    return words


def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load a YAML file into a dict.

    :raises FileNotFoundError: if *yaml_path* does not exist.  (Bug fix:
        the original raised FileExistsError for a *missing* file; both are
        OSError subclasses, so broad catchers are unaffected.)
    """
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f"The {yaml_path} does not exist.")

    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
class ONNXRuntimeError(Exception):
    """Raised when an onnxruntime inference call fails."""

    pass


class OrtInferSession:
    """Thin wrapper around ``onnxruntime.InferenceSession``.

    Prefers the CUDA execution provider when a non-negative *device_id* is
    given and CUDA is actually available; always keeps CPU as a fallback.
    """

    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4  # suppress all but fatal ORT logs
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        # CUDA first (when requested and available), CPU always appended as fallback.
        EP_list = []
        if (
            device_id != "-1"
            and get_device() == "GPU"
            and cuda_ep in get_available_providers()
        ):
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))

        self._verify_model(model_file)
        self.session = InferenceSession(
            model_file, sess_options=sess_opt, providers=EP_list
        )

        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            # Bug fix: the original called logging.warnings.warn, which only
            # works because the logging module happens to import warnings
            # internally; use the warnings module directly.  (Local import
            # keeps this block self-contained.)
            import warnings

            warnings.warn(
                f"{cuda_ep} is not available for current env, "
                f"the inference part is automatically shifted to be executed under {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
                "you can check their relations from the official web site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(
        self, input_content: List[np.ndarray]
    ) -> np.ndarray:
        """Run inference; inputs are matched positionally to model input names.

        :raises ONNXRuntimeError: wrapping any failure from the session.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            raise ONNXRuntimeError("ONNXRuntime inference failed.") from e

    def get_input_names(
        self,
    ):
        """Return the model's input tensor names, in declaration order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(
        self,
    ):
        """Return the model's output tensor names, in declaration order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        """Return the metadata entry *key* split into lines.

        NOTE: self.meta_dict is populated by have_key(); call it first.
        """
        return self.meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Check whether *key* exists in the model's custom metadata.

        Side effect: caches the metadata map on self.meta_dict for
        get_character_list().
        """
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in self.meta_dict.keys():
            return True
        return False

    @staticmethod
    def _verify_model(model_path):
        """Validate that *model_path* exists and is a regular file.

        :raises FileNotFoundError: if the path does not exist.
        :raises FileExistsError: if the path exists but is not a file
            (kept as-is for backward compatibility with existing callers).
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
4 |

Ctt punctuator

5 |
6 | 7 | 8 | ![python3.7](https://img.shields.io/badge/python-3.7-green.svg) 9 | ![python3.8](https://img.shields.io/badge/python-3.8-green.svg) 10 | ![python3.9](https://img.shields.io/badge/python-3.9-green.svg) 11 | ![python3.10](https://img.shields.io/badge/python-3.10-green.svg) 12 | 13 | 14 | 15 | A enterprise-grade Chinese-English code switch punctuator [funasr](https://github.com/alibaba-damo-academy/FunASR/). 16 | 17 | 18 | 19 |
20 |

Key Features

21 |
22 | 23 | - **General** 24 | 25 | ctt punctuator was trained on chinese-english code switch corpora. 26 | - [x] offline punctuator 27 | - [x] online punctuator 28 | - [x] punctuator for chinese-english code switch 29 | 30 | the onnx model file is 279M, you can download it from [here](https://github.com/lovemefan/CT-Transformer-punctuation/raw/main/cttpunctuator/src/onnx/punc.onnx) 31 | 32 | - **Highly Portable** 33 | 34 | ctt-punctuator reaps benefits from the rich ecosystems built around **ONNX** running everywhere where these runtimes are available. 35 | 36 | 37 | 38 | ## Installation 39 | 40 | ```bash 41 | 42 | 43 | git clone https://github.com/lovemefan/CT-Transformer-punctuation.git 44 | cd CT-Transformer-punctuation 45 | # download onnx model from huggingface 46 | wget "https://huggingface.co/lovemefan/ctt_punctuator/resolve/main/cttpunctuator/src/onnx/punc.onnx" -O cttpunctuator/src/onnx/punc.onnx 47 | # you can also download with huggingface mirror 48 | # wget "https://hf-mirror.com/lovemefan/ctt_punctuator/resolve/main/cttpunctuator/src/onnx/punc.onnx" -O cttpunctuator/src/onnx/punc.onnx 49 | pip install -e . 
50 | ``` 51 | 52 | ## Usage 53 | 54 | ```python 55 | from cttPunctuator import CttPunctuator 56 | import logging 57 | logging.basicConfig( 58 | level=logging.INFO, 59 | format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s", 60 | ) 61 | # offline mode 62 | punc = CttPunctuator() 63 | text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书" 64 | logging.info(punc.punctuate(text)[0]) 65 | 66 | # online mode 67 | punc = CttPunctuator(online=True) 68 | text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益" 69 | 70 | vads = text_in.split("|") 71 | rec_result_all = "" 72 | param_dict = {"cache": []} 73 | for vad in vads: 74 | result = punc.punctuate(vad, param_dict=param_dict) 75 | rec_result_all += result[0] 76 | logging.info(f"Part: {rec_result_all}") 77 | 78 | logging.info(f"Final: {rec_result_all}") 79 | ``` 80 | ## Result 81 | ```bash 82 | [2023-04-19 01:12:39,308 INFO] [ctt-punctuator.py:50 ctt-punctuator.__init__] Initializing punctuator model with offline mode. 83 | [2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:52 ctt-punctuator.__init__] Offline model initialized. 84 | [2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized. 85 | [2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:67 ctt-punctuator.] 据报道,纽约时报使用ChatGPT创建了一个情人节消息生成器,用户只需输入几个提示,就可以得到一封自动生成的情书。 86 | [2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:40 ctt-punctuator.__init__] Initializing punctuator model with online mode. 87 | [2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:43 ctt-punctuator.__init__] Online model initialized. 88 | [2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized. 89 | [2023-04-19 01:13:12,502 INFO] [ctt-punctuator.py:77 ctt-punctuator.] 
Partial: 跨境河流是养育沿岸 90 | [2023-04-19 01:13:12,508 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员 91 | [2023-04-19 01:13:12,521 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险 92 | [2023-04-19 01:13:12,547 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切 93 | [2023-04-19 01:13:12,553 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制 94 | [2023-04-19 01:13:12,559 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是 95 | [2023-04-19 01:13:12,560 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们 96 | [2023-04-19 01:13:12,567 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的 97 | [2023-04-19 01:13:12,572 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学 98 | [2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:77 ctt-punctuator.] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学规划和论证,兼顾上下游的利益 99 | [2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:79 ctt-punctuator.] 
class CT_Transformer:
    """
    Author: Speech Lab, Alibaba Group, China
    CT-Transformer: Controllable time-delay transformer
    for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
    ):
        """
        :param model_dir: directory holding punc.onnx / punc.bin; defaults to
            the bundled "onnx" directory next to this file.
        :param batch_size: accepted for API compatibility only — inference
            currently always runs with batch size 1 (see self.batch_size).
        :param device_id: GPU device id, "-1" for CPU-only.
        :param quantize: load model_quant.onnx instead of punc.onnx.
        :param intra_op_num_threads: onnxruntime intra-op thread count.
        :raises FileNotFoundError: if the model directory does not exist.
        """
        model_dir = model_dir or os.path.join(os.path.dirname(__file__), "onnx")
        if model_dir is None or not Path(model_dir).exists():
            raise FileNotFoundError(f"{model_dir} does not exist.")

        model_file = os.path.join(model_dir, "punc.onnx")
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
        config_file = os.path.join(model_dir, "punc.bin")
        # NOTE(review): punc.bin is unpickled — only load trusted, bundled files.
        with open(config_file, "rb") as file:
            config = pickle.load(file)

        self.converter = TokenIDConverter(config["token_list"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = 1
        self.punc_list = config["punc_list"]
        self.period = 0
        # Normalize ASCII punctuation to full-width and remember the index
        # of the period token for end-of-sentence handling.
        for i in range(len(self.punc_list)):
            if self.punc_list[i] == ",":
                self.punc_list[i] = ","
            elif self.punc_list[i] == "?":
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i

    def __call__(self, text: Union[list, str], split_size=20):
        """Punctuate *text*.

        The text is tokenized, cut into mini-sentences of at most
        *split_size* tokens, punctuated chunk by chunk (carrying unfinished
        clauses over to the next chunk), and a final period is ensured.

        :return: (punctuated_text, punctuation_id_list) tuple.
        :raises ONNXRuntimeError: if model inference fails.
        """
        split_text = code_mix_split_words(text)
        # Robustness: empty input would otherwise crash on the final
        # end-of-sentence handling below (new_mini_sentence[-1]).
        if not split_text:
            return "", []
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
            # Prepend the clause carried over from the previous chunk.
            mini_sentence = cache_sent + mini_sentence

            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int64")
            text_lengths = np.array([len(mini_sentence)], dtype="int32")

            data = {
                "text": mini_sentence_id[None, :],
                "text_lengths": text_lengths,
            }
            try:
                outputs = self.infer(data["text"], data["text_lengths"])
                y = outputs[0]
                punctuations = np.argmax(y, axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError as e:
                # Bug fix: the original only logged and fell through, which
                # then crashed with a NameError on `punctuations`.  Re-raise
                # so callers see the real inference failure.
                logging.exception(e)
                raise

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if (
                        self.punc_list[punctuations[i]] == "。"
                        or self.punc_list[punctuations[i]] == "?"
                    ):
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if (
                    sentenceEnd < 0
                    and len(mini_sentence) > cache_pop_trigger_limit
                    and last_comma_index >= 0
                ):
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                # Tokens after the sentence end are carried into the next chunk.
                cache_sent = mini_sentence[sentenceEnd + 1 :]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
                punctuations = punctuations[0 : sentenceEnd + 1]

            new_mini_sentence_punc += [int(x) for x in punctuations]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if i > 0:
                    if (
                        # Re-insert a space between two adjacent ASCII tokens.
                        len(mini_sentence[i][0].encode()) == 1
                        and len(mini_sentence[i - 1][0].encode()) == 1
                    ):
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                # "_" denotes "no punctuation after this token".
                if self.punc_list[punctuations[i]] != "_":
                    words_with_punc.append(self.punc_list[punctuations[i]])
            new_mini_sentence += "".join(words_with_punc)
            # Add Period for the end of the sentence
            new_mini_sentence_out = new_mini_sentence
            new_mini_sentence_punc_out = new_mini_sentence_punc
            if mini_sentence_i == len(mini_sentences) - 1:
                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
                        self.period
                    ]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
                    new_mini_sentence_out = new_mini_sentence + "。"
                    # NOTE(review): this appends "。" to the text but *replaces*
                    # the last punctuation id instead of appending one, leaving
                    # the id list one element shorter than the sentence-final
                    # punctuation suggests — looks like an off-by-one; behavior
                    # preserved here pending confirmation.
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
                        self.period
                    ]
        return new_mini_sentence_out, new_mini_sentence_punc_out

    def infer(
        self, feats: np.ndarray, feats_len: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on token ids (feats) and their lengths."""
        outputs = self.ort_infer([feats, feats_len])
        return outputs
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
    ):
        # Same model/config loading as the offline CT_Transformer.
        super(CT_Transformer_VadRealtime, self).__init__(
            model_dir, batch_size, device_id, quantize, intra_op_num_threads
        )

    def __call__(self, text: str, param_dict: dict, split_size: int = 20) -> Tuple[str, list, list]:
        """Restore punctuation for a streaming (VAD-segmented) chunk.

        ``param_dict["cache"]`` holds the words of the still-unterminated
        sentence from the previous call; they are prepended to ``text`` so
        the model sees the whole sentence, and the output for those cached
        words is skipped so only the new chunk is emitted. The cache is
        updated in place in ``param_dict``.

        :param text: new text chunk for this call.
        :param param_dict: mutable state dict; must contain key "cache"
            (a list of words, possibly empty/None).
        :param split_size: max tokens per mini-sentence fed to the model.
        :return: ``(sentence_out, punctuation_symbol_list, cache_out)``.
        """
        cache_key = "cache"
        assert cache_key in param_dict
        cache = param_dict[cache_key]
        if cache is not None and len(cache) > 0:
            precache = "".join(cache)
        else:
            precache = ""
            cache = []
        full_text = precache + text
        split_text = code_mix_split_words(full_text)
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        new_mini_sentence_punc = []
        assert len(mini_sentences) == len(mini_sentences_id)

        # Unterminated tail carried between mini-sentences (same idea as the
        # offline __call__, but ids stay a numpy array here).
        cache_sent = []
        cache_sent_id = np.array([], dtype="int32")
        sentence_punc_list = []
        sentence_words_list = []
        cache_pop_trigger_limit = 200
        # Number of leading (previously-cached) words already emitted in an
        # earlier call; they are skipped when building the output string.
        skip_num = 0
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            text_length = len(mini_sentence_id)
            data = {
                "input": np.array(mini_sentence_id[None, :], dtype="int64"),
                "text_lengths": np.array([text_length], dtype="int32"),
                # Attention mask separating the cached prefix from new text;
                # broadcast to (1, 1, L, L) float32 for the model.
                "vad_mask": self.vad_mask(text_length, len(cache))[
                    None, None, :, :
                ].astype(np.float32),
                # Causal (lower-triangular) mask, same (1, 1, L, L) layout.
                "sub_masks": np.tril(
                    np.ones((text_length, text_length), dtype=np.float32)
                )[None, None, :, :].astype(np.float32),
            }
            try:
                outputs = self.infer(
                    data["input"],
                    data["text_lengths"],
                    data["vad_mask"],
                    data["sub_masks"],
                )
                y = outputs[0]
                # One predicted punctuation id per input token.
                punctuations = np.argmax(y, axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError as e:
                # NOTE(review): logged but not re-raised; if inference fails
                # on the first iteration, `punctuations` is unbound below
                # (NameError) — confirm intended.
                logging.exception(e)

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
                last_comma_index = -1
                # Scan backwards (skipping the final token) for a terminator;
                # remember the last comma as a fallback cut point.
                for i in range(len(punctuations) - 2, 1, -1):
                    if (
                        self.punc_list[punctuations[i]] == "。"
                        or self.punc_list[punctuations[i]] == "?"
                    ):
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if (
                    sentenceEnd < 0
                    and len(mini_sentence) > cache_pop_trigger_limit
                    and last_comma_index >= 0
                ):
                    # The sentence is too long, cut off at a comma and
                    # promote that comma to a period.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                # Everything past the cut becomes the cache for next round.
                cache_sent = mini_sentence[sentenceEnd + 1 :]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
                punctuations = punctuations[0 : sentenceEnd + 1]

            punctuations_np = [int(x) for x in punctuations]
            new_mini_sentence_punc += punctuations_np
            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
            sentence_words_list += mini_sentence

        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
        for i in range(0, len(sentence_words_list)):
            if i > 0:
                # Insert a space between two adjacent ASCII (Latin) words.
                if (
                    len(sentence_words_list[i][0].encode()) == 1
                    and len(sentence_words_list[i - 1][-1].encode()) == 1
                ):
                    sentence_words_list[i] = " " + sentence_words_list[i]
            if skip_num < len(cache):
                # Word came from the previous call's cache: already emitted.
                skip_num += 1
            else:
                words_with_punc.append(sentence_words_list[i])
            if skip_num >= len(cache):
                sentence_punc_list_out.append(sentence_punc_list[i])
                # "_" is the blank label: no punctuation after this token.
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)

        # Words after the last period/question mark form the new cache.
        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
                sentenceEnd = i
                break
        cache_out = sentence_words_list[sentenceEnd + 1 :]
        # Strip a trailing punctuation mark so the (possibly continuing)
        # sentence is not prematurely terminated in the streamed output.
        # NOTE(review): raises IndexError when sentence_out is empty (e.g.
        # all words were skipped as cache) — confirm callers never hit this.
        if sentence_out[-1] in self.punc_list:
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        param_dict[cache_key] = cache_out
        return sentence_out, sentence_punc_list_out, cache_out

    def vad_mask(self, size, vad_pos, dtype=np.bool_) -> np.ndarray:
        """Create mask for decoder self-attention.

        Positions before ``vad_pos`` (the cached prefix) are blocked from
        attending to positions at or after it, except for the row just
        before the boundary.

        :param int size: size of mask (sequence length)
        :param int vad_pos: index of the vad boundary within the sequence
        :param np.dtype dtype: result dtype
        :rtype: np.ndarray of shape (size, size)
        """
        ret = np.ones((size, size), dtype=dtype)
        # Boundary outside the sequence: nothing to mask.
        if vad_pos <= 0 or vad_pos >= size:
            return ret
        # Zero out the upper-right corner: prefix rows cannot attend forward.
        sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
        ret[0 : vad_pos - 1, vad_pos:] = sub_corner
        return ret

    def infer(
        self,
        feats: np.ndarray,
        feats_len: np.ndarray,
        vad_mask: np.ndarray,
        sub_masks: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session with explicit vad and causal masks."""
        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
        return outputs