├── version.txt
├── cttpunctuator
├── src
│ ├── onnx
│ │ ├── punc.bin
│ │ └── configuration.json
│ ├── utils
│ │ ├── text_post_process.py
│ │ └── OrtInferSession.py
│ └── punctuator.py
└── __init__.py
├── MANIFEST.in
├── .flake8
├── test
└── test.py
├── LICENSE
├── .github
└── workflows
│ └── python-package.yml
├── cttPunctuator.py
├── setup.py
├── .gitignore
└── README.md
/version.txt:
--------------------------------------------------------------------------------
1 | 0.0.2
--------------------------------------------------------------------------------
/cttpunctuator/src/onnx/punc.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lovemefan/CT-Transformer-punctuation/HEAD/cttpunctuator/src/onnx/punc.bin
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include cttpunctuator/src/onnx/configuration.json
2 | include cttpunctuator/src/onnx/punc.onnx
3 | include cttpunctuator/src/onnx/punc.yaml
--------------------------------------------------------------------------------
/cttpunctuator/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :__init__.py
3 | # @Time :2023/4/13 14:58
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | show-source=true
3 | statistics=true
4 | max-line-length = 108
5 | per-file-ignores =
6 | # line too long
7 | runtime/src/utils/kaldifeat/feature.py : E501
8 | runtime/src/utils/kaldifeat/ivector.py : E501
9 |
10 | exclude =
11 | .git,
12 | **/data/**
13 | **/onnx/**
14 |
15 | ignore =
16 | # E203 white space before ":"
17 | E203,
18 | # W503 line break before binary operator
19 | W503,
20 | # E226 missing whitespace around arithmetic operator
21 | E226,
22 |
--------------------------------------------------------------------------------
/cttpunctuator/src/onnx/configuration.json:
--------------------------------------------------------------------------------
1 | {
2 | "framework": "onnx",
3 | "task" : "punctuation",
4 | "model" : {
5 | "type" : "generic-punc",
6 | "punc_model_name" : "punc.pb",
7 | "punc_model_config" : {
8 | "type": "pytorch",
9 | "code_base": "funasr",
10 | "mode": "punc",
11 | "lang": "zh-cn",
12 | "batch_size": 1,
13 | "punc_config": "punc.yaml",
14 | "model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
15 | }
16 | },
17 | "pipeline": {
18 | "type":"punc-inference"
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :test.py
3 | # @Time :2023/4/19 13:39
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 |
7 | import logging
8 |
9 | from cttPunctuator import CttPunctuator
10 |
# Demo / smoke-test script: punctuate one text in offline mode, then feed a
# "|"-separated text to the online punctuator segment by segment.
logging.basicConfig(
    format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
    level=logging.INFO,
)

# --- offline mode: the whole text is punctuated in a single call ---
punc = CttPunctuator()
text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书"
offline_result = punc.punctuate(text)
logging.info(offline_result[0])

# --- online mode: each "|"-delimited segment is punctuated in turn ---
punc = CttPunctuator(online=True)
text_in = (
    "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|"
    "在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|"
    "向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|"
    "愿意进一步完善双方联合工作机制|凡是|中方能做的我们|"
    "都会去做而且会做得更好我请印度朋友们放心中国在上游的|"
    "任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
)

rec_result_all = ""
for segment in text_in.split("|"):
    # Element 0 of the returned value is the punctuated text of this segment.
    rec_result_all += punc.punctuate(segment)[0]
    # Log the text accumulated so far after every segment.
    logging.info(f"Part: {rec_result_all}")

logging.info(f"Final: {rec_result_all}")
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014-2017 Alexey Popravka
4 | Copyright (c) 2021 Sean Stewart
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ${{matrix.os}}
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | os: [ubuntu-20.04, windows-2019, macos-11]
20 | python-version: ["3.7.2", "3.8.0", "3.9.0", "3.10.0", "3.11.4"]
21 |
22 | steps:
23 | - uses: actions/checkout@v3
24 | - name: Set up Python ${{ matrix.python-version }}
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 | - name: Install dependencies
29 | run: |
30 | python3 -m pip install --upgrade pip
31 | pip install -U setuptools wheel
32 | - name: Test ctt punctuator
33 | run: |
34 | curl -L "https://huggingface.co/lovemefan/ctt_punctuator/resolve/main/cttpunctuator/src/onnx/punc.onnx" -o cttpunctuator/src/onnx/punc.onnx
35 | pip3 install -e .
36 | python3 test/test.py
37 |
--------------------------------------------------------------------------------
/cttPunctuator.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :cttPunctuator.py
3 | # @Time :2023/4/13 15:03
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 |
7 |
8 | __author__ = "lovemefan"
9 | __copyright__ = "Copyright (C) 2023 lovemefan"
10 | __license__ = "MIT"
11 | __version__ = "v0.0.1"
12 |
13 | import logging
14 | import threading
15 |
16 | from cttpunctuator.src.punctuator import CT_Transformer, CT_Transformer_VadRealtime
17 |
# Configure the root logger as a side effect of importing this module. Each
# record is formatted with timestamp, level, file name, line number, module
# and function name ahead of the message.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
)

# Module-wide re-entrant lock; CttPunctuator.__init__ holds it while creating
# the class-level model objects.
lock = threading.RLock()
24 |
25 |
class CttPunctuator:
    """Punctuation front end that shares one model object per mode.

    The offline and online models are stored on the class, created the first
    time an instance of the corresponding mode is constructed, and reused by
    every later instance. Each online instance owns its own ``param_dict``.
    """

    # Class-level model objects, created on first use.
    _offline_model = None
    _online_model = None

    def __init__(self, online: bool = False):
        """
        punctuator with singleton pattern
        :param online: when True use the ``CT_Transformer_VadRealtime`` model,
            otherwise the ``CT_Transformer`` model.
        """
        self.online = online

        if online:
            # Unlocked check first, then re-check while holding the lock so
            # the model is only constructed once.
            if CttPunctuator._online_model is None:
                with lock:
                    if CttPunctuator._online_model is None:
                        logging.info("Initializing punctuator model with online mode.")
                        CttPunctuator._online_model = CT_Transformer_VadRealtime()
                        logging.info("Online model initialized.")
            # Always give the instance its own parameter dict: it must exist
            # even when the shared model was created by an earlier instance,
            # otherwise punctuate() fails with AttributeError.
            self.param_dict = {"cache": []}
            self.model = CttPunctuator._online_model

        else:
            if CttPunctuator._offline_model is None:
                with lock:
                    if CttPunctuator._offline_model is None:
                        logging.info("Initializing punctuator model with offline mode.")
                        CttPunctuator._offline_model = CT_Transformer()
                        logging.info("Offline model initialized.")
            self.model = CttPunctuator._offline_model

        logging.info("Model initialized.")

    def punctuate(self, text: str, param_dict=None):
        """Run the model on ``text`` and return whatever the model returns.

        :param text: input text to punctuate.
        :param param_dict: online mode only; dict passed to the model as its
            second argument. When omitted (or empty) the instance's own
            ``param_dict`` is used. Ignored in offline mode.
        :return: the model's return value, unchanged.
        """
        if self.online:
            # Forward the caller's dict when one is supplied; previously the
            # instance dict was always passed and the argument was ignored.
            param_dict = param_dict or self.param_dict
            return self.model(text, param_dict)
        return self.model(text)
64 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :setup.py
3 | # @Time :2023/4/4 11:22
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 | import os
7 | from pathlib import Path
8 |
9 | from setuptools import find_namespace_packages, setup
10 |
# Packaging script: reads the version and long description from files that sit
# next to this script and registers the ``cttpunctuator`` distribution.
dirname = Path(os.path.dirname(__file__))
version_file = dirname / "version.txt"
with open(version_file, "r") as f:
    version = f.read().strip()

# Read the README through a context manager so the file handle is closed.
with open(dirname / "README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

requirements = {
    "install": [
        "setuptools<=65.0",
        "PyYAML",
        "typeguard==2.13.3",
        "onnxruntime>=1.14.1",
        # numpy is imported directly by the package modules, so declare it
        # explicitly (left unpinned so every supported Python can resolve it).
        "numpy",
    ],
    "setup": [
        "numpy==1.24.2",
    ],
    "all": [],
}
requirements["all"].extend(requirements["install"])

install_requires = requirements["install"]
# NOTE(review): kept for reference only; it is not passed to setup() below.
setup_requires = requirements["setup"]


setup(
    name="cttpunctuator",
    version=version,
    url="https://github.com/lovemefan/CT-Transformer-punctuation",
    author="Lovemefan, Yunnan Key Laboratory of Artificial Intelligence, "
    "Kunming University of Science and Technology, Kunming, Yunnan ",
    author_email="lovemefan@outlook.com",
    description="ctt-punctuator: A enterprise-grade punctuator after chinese asr based "
    "on ct-transformer from funasr opensource",
    long_description=long_description,
    long_description_content_type="text/markdown",
    license="The MIT License",
    packages=find_namespace_packages(),
    include_package_data=True,
    install_requires=install_requires,
    python_requires=">=3.7.0",
    classifiers=[
        "Programming Language :: Python",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Science/Research",
        "Operating System :: POSIX :: Linux",
        # Must agree with ``license`` above and the LICENSE file (MIT).
        "License :: OSI Approved :: MIT License",
        "Topic :: Multimedia :: Sound/Audio :: Speech",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
65 |
--------------------------------------------------------------------------------
/cttpunctuator/src/utils/text_post_process.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :text_post_process.py
3 | # @Time :2023/4/13 15:09
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 | from pathlib import Path
7 | from typing import Dict, Iterable, List, Union
8 |
9 | import numpy as np
10 | import yaml
11 | from typeguard import check_argument_types
12 |
13 |
class TokenIDConverterError(Exception):
    """Error raised by TokenIDConverter for input it cannot convert."""
16 |
17 |
class TokenIDConverter:
    """Map tokens to their index in a token list and back.

    The last entry of the token list is used as the unknown-token symbol;
    tokens that are not in the list are mapped to its index.
    """

    def __init__(
        self,
        token_list: Union[List, str],
    ):
        """Build the token -> index table from ``token_list``."""
        check_argument_types()

        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        # Later duplicates overwrite earlier ones, so a repeated token maps to
        # its last position.
        self.token2id = {}
        for position, token in enumerate(self.token_list):
            self.token2id[token] = position
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the number of entries in the token list."""
        size = len(self.token_list)
        return size

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Look up the token for every index in ``integers``.

        Raises TokenIDConverterError when ``integers`` is an ndarray that is
        not one-dimensional.
        """
        is_array = isinstance(integers, np.ndarray)
        if is_array and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}"
            )
        vocabulary = self.token_list
        return [vocabulary[index] for index in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Look up the index of every token, using ``unk_id`` for unknown ones."""
        lookup = self.token2id.get
        return [lookup(token, self.unk_id) for token in tokens]
42 |
43 |
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Cut ``words`` into consecutive chunks of at most ``word_limit`` items.

    When ``words`` already fits in one chunk it is returned as-is, wrapped in
    a single-element list. Otherwise every chunk has ``word_limit`` items
    except possibly the last, which holds the remainder.
    """
    assert word_limit > 1
    total = len(words)
    if total <= word_limit:
        return [words]
    # Stepping by word_limit yields every full chunk plus the trailing
    # remainder, if there is one.
    return [words[start : start + word_limit] for start in range(0, total, word_limit)]
56 |
57 |
def code_mix_split_words(text: str):
    """Split mixed text into words.

    The text is first split on whitespace. Inside each piece, consecutive
    characters that encode to a single byte (ASCII) are kept together as one
    word, while every other character (e.g. a Chinese character) becomes a
    word of its own.
    """
    tokens = []
    for piece in text.split():
        # Characters of the ASCII run currently being collected.
        pending = []
        for char in piece:
            if len(char.encode()) == 1:
                pending.append(char)
                continue
            # A multi-byte character ends the ASCII run collected so far.
            if pending:
                tokens.append("".join(pending))
                pending = []
            tokens.append(char)
        # Emit the ASCII run left over at the end of the piece.
        if pending:
            tokens.append("".join(pending))
    return tokens
77 |
78 |
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load the YAML document stored at ``yaml_path`` and return it.

    Raises FileExistsError when the path does not exist.
    """
    # NOTE(review): a missing file raises FileExistsError, not
    # FileNotFoundError; kept as-is so existing callers are unaffected.
    path_missing = not Path(yaml_path).exists()
    if path_missing:
        raise FileExistsError(f"The {yaml_path} does not exist.")

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
    # this must only be used on trusted files.
    with open(str(yaml_path), "rb") as stream:
        return yaml.load(stream, Loader=yaml.Loader)
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
--------------------------------------------------------------------------------
/cttpunctuator/src/utils/OrtInferSession.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :OrtInferSession.py
3 | # @Time :2023/4/13 15:13
4 | # @Author :lovemefan
5 | # @Email :lovemefan@outlook.com
6 | import logging
7 | from pathlib import Path
8 | from typing import List, Union
9 |
10 | import numpy as np
11 | from onnxruntime import (
12 | GraphOptimizationLevel,
13 | InferenceSession,
14 | SessionOptions,
15 | get_available_providers,
16 | get_device,
17 | )
18 |
19 |
class ONNXRuntimeError(Exception):
    """Error raised when running the ONNX Runtime session fails."""
22 |
23 |
class OrtInferSession:
    """Wrapper around an onnxruntime ``InferenceSession`` for one model file.

    The session is created in ``__init__``; calling the instance runs the
    model on a list of input arrays.
    """

    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        """Create the inference session.

        :param model_file: path of the ONNX model file; it must exist and be
            a regular file.
        :param device_id: CUDA device id, or -1 to use the CPU provider only.
        :param intra_op_num_threads: value assigned to the session option of
            the same name.
        :raises FileNotFoundError: if ``model_file`` does not exist.
        :raises FileExistsError: if ``model_file`` is not a file.
        """
        # Local import: use the warnings module directly instead of reaching
        # it through ``logging.warnings``, which only exists because the
        # logging module happens to import it.
        import warnings

        # Provider options and the comparisons below use the id as a string.
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        # CUDA goes first when it was requested and is available; the CPU
        # provider is always appended after it.
        EP_list = []
        if (
            device_id != "-1"
            and get_device() == "GPU"
            and cuda_ep in get_available_providers()
        ):
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))

        self._verify_model(model_file)
        self.session = InferenceSession(
            model_file, sess_options=sess_opt, providers=EP_list
        )

        # A GPU was requested but the session did not end up with the CUDA
        # provider: tell the user inference runs on the CPU instead.
        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not avaiable for current env, "
                f"the inference part is automatically shifted to be executed under {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
                "you can check their relations from the offical web site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(
        self, input_content: List[Union[np.ndarray, np.ndarray]]
    ) -> np.ndarray:
        """Run the model.

        ``input_content`` is paired positionally with the session's input
        names. Returns the result of ``session.run`` for all output names.

        :raises ONNXRuntimeError: if ``session.run`` raises; the original
            exception is chained as the cause.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e

    def get_input_names(
        self,
    ):
        """Return the names of the session's inputs, in session order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(
        self,
    ):
        """Return the names of the session's outputs, in session order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        """Return the metadata entry ``key`` split into lines.

        The metadata map is loaded on demand, so this works whether or not
        ``have_key`` was called first.

        :raises KeyError: if ``key`` is not in the model's custom metadata.
        """
        meta = getattr(self, "meta_dict", None)
        if meta is None:
            meta = self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Return True if the model's custom metadata contains ``key``.

        Also refreshes ``self.meta_dict`` from the session.
        """
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in self.meta_dict

    @staticmethod
    def _verify_model(model_path):
        """Check that ``model_path`` exists and is a regular file.

        :raises FileNotFoundError: if the path does not exist.
        :raises FileExistsError: if the path exists but is not a file.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
104 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |