├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── ctcdecode ├── __init__.py ├── decoder │ ├── __init__.py │ ├── base.py │ ├── beam_search.py │ ├── best_path.py │ └── substring_beam_search.py ├── prefix.py └── scorer.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .venv/ 132 | playground/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Matthias 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE README.md 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # py-ctc-decode 2 | Decoding output of CTC trained models. 3 | Implementation was adapted from [https://github.com/PaddlePaddle/DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech), but in python for easier modifications. If you need the speed, use the Paddle implementation. 4 | - Best Path 5 | - Beam Search 6 | - Beam Search with LM 7 | 8 | ## Usage 9 | Following examples show how to use the different decoding strategies. 10 | It is assumed that ``_`` is the blank symbol. 11 | Probabilites are expected to be in natural log. 12 | 13 | ### Best Path 14 | ```python 15 | logits = [] # TxV 16 | vocabulary = [' ', 'a', 'b', 'c', '_'] # V 17 | decoder = ctcdecode.BestPathDecoder( 18 | vocabulary, 19 | num_workers=4 20 | ) 21 | 22 | predictions = decoder.decode_batch(logits) 23 | prediction = decoder.decode(logits[0]) 24 | ``` 25 | 26 | ### Beam Search 27 | ```python 28 | logits = [] # TxV 29 | vocabulary = [' ', 'a', 'b', 'c', '_'] # V 30 | decoder = ctcdecode.BeamSearchDecoder( 31 | vocabulary, 32 | num_workers=4, 33 | beam_width=64, 34 | cutoff_prob=np.log(0.000001), 35 | cutoff_top_n=40 36 | ) 37 | 38 | predictions = decoder.decode_batch(logits) 39 | prediction = decoder.decode(logits[0]) 40 | ``` 41 | 42 | ### Beam Search with LM 43 | ```python 44 | logits = [ 45 | [-1.1906, -1.0623, -1.7766, -1.7086], 46 | [-1.4091, -1.4424, -1.1923, -1.5336], 47 | [-1.4091, -1.6900, -1.6956, -0.9477], 48 | [-1.3715, -1.2527, -1.7445, -1.2524], 49 | [-1.2577, -1.2588, -1.3380, -1.7759] 50 | ] # 5x4 TimeSteps x Softmax over Vocabulary (NATURAL LOG !!!) 51 | vocabulary = [' ', 'a', 'b', '_'] # 4 52 | alpha = 2.5 # LM Weight 53 | beta = 0.0 # LM Usage Reward 54 | word_lm_scorer = ctcdecode.WordKenLMScorer('path/to/kenlm', alpha, beta) 55 | decoder = ctcdecode.BeamSearchDecoder( 56 | vocabulary, 57 | num_workers=4, 58 | beam_width=64, 59 | scorers=[word_lm_scorer], 60 | cutoff_prob=np.log(0.000001), 61 | cutoff_top_n=40 62 | ) 63 | 64 | prediction = decoder.decode(logits) # text (e.g. "a b") 65 | 66 | # Batch decoding for multiple utterances 67 | batch = [logits, ....] 68 | predictions = decoder.decode_batch(batch) 69 | ``` 70 | -------------------------------------------------------------------------------- /ctcdecode/__init__.py: -------------------------------------------------------------------------------- 1 | from ctcdecode.decoder.best_path import BestPathDecoder 2 | from ctcdecode.decoder.beam_search import BeamSearchDecoder 3 | from ctcdecode.decoder.substring_beam_search import SubstringBeamSearchDecoder 4 | 5 | from ctcdecode.scorer import WordKenLMScorer 6 | from ctcdecode.scorer import CharOfWordKenLMScorer 7 | 8 | __version__ = '0.0.0' 9 | -------------------------------------------------------------------------------- /ctcdecode/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ynop/py-ctc-decode/2719bd9e5f539dfd24c7a4d0dff3c7dcb092c076/ctcdecode/decoder/__init__.py -------------------------------------------------------------------------------- /ctcdecode/decoder/base.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | from tqdm import tqdm 4 | import psutil 5 | 6 | 7 | class Decoder: 8 | 9 | def __init__(self, vocab, num_workers=4, fix_cpu_per_process=True): 10 | self.num_workers = num_workers 11 | self.vocab = vocab 12 | self.fix_cpu_per_process = fix_cpu_per_process 13 | 14 | def decode(self, probs): 15 | pass 16 | 17 | def decode_batch(self, prob_list): 18 | if psutil.LINUX: 19 | p = psutil.Process() 20 | available_cpus = p.cpu_affinity() 21 | else: 22 | available_cpus = [None]*1000 23 | 24 | num_procs = min(self.num_workers, len(available_cpus)) 25 | decode_procs = [] 26 | 27 | tasks = multiprocessing.Queue() 28 | results = multiprocessing.Queue() 29 | 30 | for i in range(num_procs): 31 | if self.fix_cpu_per_process: 32 | cpu_id = available_cpus[i] 33 | else: 34 | cpu_id = None 35 | 36 | decode_procs.append( 37 | DecoderProcess(self, cpu_id, tasks, results) 38 | ) 39 | 40 | for p in decode_procs: 41 | p.start() 42 | 43 | for index, probs in enumerate(prob_list): 44 | tasks.put((index, probs)) 45 | 46 | predictions = [] 47 | 48 | for i in tqdm(range(len(prob_list))): 49 | predictions.append(results.get()) 50 | 51 | # Send end 52 | for i in range(num_procs): 53 | tasks.put(None) 54 | 55 | predictions = sorted(predictions, key=lambda x: x[0]) 56 | return [p[1] for p in predictions] 57 | 58 | 59 | class DecoderProcess(multiprocessing.Process): 60 | 61 | def __init__(self, decoder, cpu_id, task_queue, result_queue): 62 | multiprocessing.Process.__init__(self) 63 | 64 | self.decoder = decoder 65 | self.cpu_id = cpu_id 66 | self.task_queue = task_queue 67 | self.result_queue = result_queue 68 | 69 | def run(self): 70 | if self.cpu_id is not None: 71 | p = psutil.Process() 72 | p.cpu_affinity(cpus=[self.cpu_id]) 73 | 74 | for index, probs in iter(self.task_queue.get, None): 75 | prediction = self.decoder.decode(probs) 76 | self.result_queue.put((index, prediction)) 77 | -------------------------------------------------------------------------------- /ctcdecode/decoder/beam_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ctcdecode.prefix import State 4 | from ctcdecode.decoder import base 5 | 6 | 7 | class BeamSearchDecoder(base.Decoder): 8 | 9 | def __init__(self, vocab, num_workers=4, beam_width=64, scorers=None, 10 | cutoff_prob=1.0, cutoff_top_n=40): 11 | super(BeamSearchDecoder, self).__init__(vocab, num_workers=num_workers) 12 | 13 | self.beam_width = beam_width 14 | self.scorers = scorers 15 | self.cutoff_prob = cutoff_prob 16 | self.cutoff_top_n = cutoff_top_n 17 | 18 | def decode(self, probs): 19 | # Num time steps 20 | nT = probs.shape[0] 21 | 22 | # Initialize prefixes 23 | prefixes = State( 24 | scorers=self.scorers, 25 | size=self.beam_width 26 | ) 27 | 28 | # Iterate over timesteps 29 | for t in range(nT): 30 | step_probs = probs[t] 31 | pruned_step_probs = self._get_pruned_vocab_indices(step_probs) 32 | 33 | # Iterate over symbols 34 | for v in pruned_step_probs: 35 | symbol = self.vocab[v] 36 | symbol_prob = step_probs[v] 37 | 38 | # Iterate over prefixes 39 | for prefix in prefixes: 40 | 41 | # If there is a blank, we extend the existing prefix 42 | if symbol == '_': 43 | prefix.add_p_blank(symbol_prob + prefix.score) 44 | 45 | else: 46 | 47 | # If the last symbol is repeated 48 | # update the existing prefix 49 | if symbol == prefix.symbol: 50 | p = symbol_prob + prefix.p_non_blank_prev 51 | prefix.add_p_non_blank(p) 52 | 53 | new_prefix = prefixes.get_prefix(prefix, symbol) 54 | 55 | if new_prefix is not None: 56 | p = -np.inf 57 | 58 | if symbol == prefix.symbol and \ 59 | prefix.p_blank_prev > -np.inf: 60 | p = prefix.p_blank_prev + symbol_prob 61 | 62 | elif prefix.symbol != symbol: 63 | p = prefix.score + symbol_prob 64 | 65 | new_prefix.add_p_non_blank(p) 66 | 67 | prefixes.step() 68 | 69 | prefixes.finalize() 70 | 71 | return prefixes.best() 72 | 73 | def _get_pruned_vocab_indices(self, log_probs): 74 | """ Return vocab indices of pruned probabilities of a time step. """ 75 | 76 | index_to_prob = [(k, log_probs[k]) for k in range(log_probs.shape[0])] 77 | index_to_prob = sorted(index_to_prob, key=lambda x: x[1], reverse=True) 78 | 79 | if self.cutoff_top_n < len(index_to_prob): 80 | index_to_prob = index_to_prob[:self.cutoff_top_n] 81 | 82 | if self.cutoff_prob < 1.0: 83 | filtered = [] 84 | for x in index_to_prob: 85 | if x[1] >= self.cutoff_prob: 86 | filtered.append(x) 87 | index_to_prob = filtered 88 | 89 | return [x[0] for x in index_to_prob] 90 | -------------------------------------------------------------------------------- /ctcdecode/decoder/best_path.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ctcdecode.decoder import base 4 | 5 | 6 | class BestPathDecoder(base.Decoder): 7 | 8 | def decode(self, probs): 9 | pred = [] 10 | for t in np.argmax(probs, axis=1): 11 | c = self.vocab[t] 12 | if len(pred) == 0 or pred[-1] != c: 13 | pred.append(c) 14 | 15 | pred = ''.join(pred).replace('_', '') 16 | return pred 17 | -------------------------------------------------------------------------------- /ctcdecode/decoder/substring_beam_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ctcdecode.prefix import State 4 | from ctcdecode.decoder import beam_search 5 | 6 | 7 | class SubstringBeamSearchDecoder(beam_search.BeamSearchDecoder): 8 | """ 9 | Decoder using Beam-Search. 10 | In contrast to the default implementation, 11 | this creates new prefixes by adding substrings of symbols. 12 | 13 | e.g. Given Prefix is "hell" and a symbol "lo" comes in. 14 | A Prefix "hello" is created as well. 15 | """ 16 | 17 | def __init__(self, vocab, num_workers=4, beam_width=64, scorers=None, 18 | cutoff_prob=1.0, cutoff_top_n=40, only_repeating=True): 19 | super(SubstringBeamSearchDecoder, self).__init__( 20 | vocab, num_workers=num_workers, beam_width=beam_width, 21 | scorers=scorers, cutoff_prob=cutoff_prob, cutoff_top_n=cutoff_top_n 22 | ) 23 | 24 | self.repeating = only_repeating 25 | 26 | def decode(self, probs): 27 | # Num time steps 28 | nT = probs.shape[0] 29 | 30 | # Initialize prefixes 31 | prefixes = State( 32 | scorers=self.scorers, 33 | size=self.beam_width 34 | ) 35 | 36 | # Iterate over timesteps 37 | for t in range(nT): 38 | step_probs = probs[t] 39 | pruned_step_probs = self._get_pruned_vocab_indices(step_probs) 40 | 41 | # Iterate over symbols 42 | for v in pruned_step_probs: 43 | symbol = self.vocab[v] 44 | symbol_prob = step_probs[v] 45 | 46 | # Iterate over prefixes 47 | for prefix in prefixes: 48 | 49 | # If there is a blank, we extend the existing prefix 50 | if symbol == '_': 51 | prefix.add_p_blank(symbol_prob + prefix.score) 52 | 53 | else: 54 | partial_symbols = [symbol] 55 | 56 | for i in range(1, len(symbol) + 1): 57 | if prefix.symbol is not None and \ 58 | (not self.repeating or prefix.symbol.endswith(symbol[:i])): 59 | partial_symbols.append(symbol[i:]) 60 | 61 | for partial_sym in partial_symbols: 62 | 63 | # If the last symbol is repeated 64 | # update the existing prefix 65 | if partial_sym == prefix.symbol: 66 | p = symbol_prob + prefix.p_non_blank_prev 67 | prefix.add_p_non_blank(p) 68 | 69 | new_prefix = prefixes.get_prefix( 70 | prefix, partial_sym 71 | ) 72 | 73 | if new_prefix is not None: 74 | p = -np.inf 75 | 76 | if partial_sym == prefix.symbol and \ 77 | prefix.p_blank_prev > -np.inf: 78 | p = prefix.p_blank_prev + symbol_prob 79 | elif prefix.symbol != partial_sym: 80 | p = prefix.score + symbol_prob 81 | 82 | new_prefix.add_p_non_blank(p) 83 | 84 | prefixes.step() 85 | 86 | prefixes.finalize() 87 | 88 | return prefixes.best() 89 | -------------------------------------------------------------------------------- /ctcdecode/prefix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Prefix: 5 | """ 6 | Class holding the state of a single Prefix/Beam. 7 | """ 8 | 9 | __slots__ = [ 10 | 'value', 'symbol', 'p_blank', 'p_non_blank', 11 | 'p_blank_prev', 'p_non_blank_prev', 'score', 'ext_weight' 12 | ] 13 | 14 | def __init__(self): 15 | self.value = '' 16 | self.symbol = None 17 | 18 | self.p_blank = -np.inf 19 | self.p_non_blank = -np.inf 20 | 21 | self.p_blank_prev = 0.0 22 | self.p_non_blank_prev = -np.inf 23 | 24 | self.score = 0.0 25 | 26 | self.ext_weight = 0.0 27 | 28 | def __repr__(self): 29 | return 'Prefix("{}", {}, "{}", {}, {})'.format( 30 | self.value, self.score, self.symbol, 31 | self.p_blank_prev, 32 | self.p_non_blank_prev 33 | ) 34 | 35 | def step(self): 36 | self.p_blank_prev = self.p_blank 37 | self.p_non_blank_prev = self.ext_weight + self.p_non_blank 38 | 39 | self.score = np.logaddexp(self.p_blank_prev, self.p_non_blank_prev) 40 | 41 | self.p_blank = -np.inf 42 | self.p_non_blank = -np.inf 43 | 44 | def add_p_blank(self, p): 45 | self.p_blank = np.logaddexp(self.p_blank, p) 46 | 47 | def add_p_non_blank(self, p): 48 | self.p_non_blank = np.logaddexp(self.p_non_blank, p) 49 | 50 | 51 | class State: 52 | """ 53 | Class holding the state of the decoding process. 54 | """ 55 | 56 | def __init__(self, size=64, scorers=None): 57 | self.prefixes = { 58 | '': Prefix() 59 | } 60 | self.step_prefixes = {} 61 | self.prev_prefixes = {} 62 | 63 | self.size = size 64 | self.scorers = scorers or [] 65 | 66 | def __iter__(self): 67 | for p in list(self.prefixes.values()): 68 | yield p 69 | 70 | def get_prefix(self, prefix, symbol): 71 | new_value = prefix.value + symbol 72 | 73 | if new_value in self.prefixes.keys(): 74 | return self.prefixes[new_value] 75 | elif new_value in self.prev_prefixes.keys(): 76 | new_prefix = self.prev_prefixes[new_value] 77 | self.step_prefixes[new_value] = new_prefix 78 | return new_prefix 79 | else: 80 | new_prefix = Prefix() 81 | new_prefix.value = new_value 82 | 83 | new_prefix.p_blank_prev = prefix.p_blank_prev 84 | new_prefix.p_non_blank_prev = prefix.p_non_blank_prev 85 | new_prefix.score = prefix.score 86 | 87 | new_prefix.p_blank = -np.inf 88 | new_prefix.p_non_blank = -np.inf 89 | 90 | new_prefix.symbol = symbol 91 | new_prefix.ext_weight = 0.0 92 | 93 | for scorer in self.scorers: 94 | if not scorer.is_valid_prefix(new_value): 95 | return None 96 | 97 | new_prefix.ext_weight += scorer.score_prefix(new_prefix) 98 | 99 | self.step_prefixes[new_value] = new_prefix 100 | return new_prefix 101 | 102 | return None 103 | 104 | def step(self): 105 | self.prefixes.update(self.step_prefixes) 106 | self.step_prefixes = {} 107 | 108 | for prefix in self.prefixes.values(): 109 | prefix.step() 110 | 111 | p_sorted = sorted( 112 | self.prefixes.items(), 113 | key=lambda x: x[1].score, 114 | reverse=True 115 | ) 116 | 117 | self.prefixes = {} 118 | for value, prefix in p_sorted[:self.size]: 119 | self.prefixes[value] = prefix 120 | 121 | self.prev_prefixes = {} 122 | for value, prefix in p_sorted[self.size:]: 123 | self.prev_prefixes[value] = prefix 124 | 125 | def finalize(self): 126 | for scorer in self.scorers: 127 | for prefix in self: 128 | ext_score = scorer.final_prefix_score(prefix) 129 | prefix.score += ext_score 130 | 131 | def best(self): 132 | p_sorted = sorted( 133 | self.prefixes.items(), 134 | key=lambda x: x[1].score, 135 | reverse=True 136 | ) 137 | return p_sorted[0][0] 138 | -------------------------------------------------------------------------------- /ctcdecode/scorer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import uuid 3 | 4 | import numpy as np 5 | import kenlm 6 | 7 | OOV_WORD_PENALTY = -1000.0 8 | 9 | 10 | class Scorer: 11 | """ 12 | Base class for a external scorer. 13 | This can be used to integrate for example a language model. 14 | """ 15 | 16 | def score_prefix(self, prefix): 17 | """ 18 | Return a score (log base e) for the given prefix. 19 | """ 20 | pass 21 | 22 | def final_prefix_score(self, prefix): 23 | """ 24 | Return a score (log base e) for the given prefix, 25 | considering the prefix won't be extended anymore. 26 | This is called for every prefix at the end, 27 | whether ``score_prefix`` was already called or not. 28 | """ 29 | pass 30 | 31 | def is_valid_prefix(self, value): 32 | """ 33 | Return ``True``, if the given prefix is valid. 34 | """ 35 | pass 36 | 37 | 38 | class WordKenLMScorer(Scorer): 39 | 40 | def __init__(self, path, alpha, beta): 41 | self.path = path 42 | self.alpha = alpha 43 | self.beta = beta 44 | 45 | self.lm = kenlm.Model(path) 46 | 47 | self.words = self._get_words(path) 48 | self.word_prefixes = self._get_word_prefixes(self.words) 49 | 50 | self.idx = uuid.uuid1() 51 | 52 | def score_prefix(self, prefix): 53 | if prefix.symbol == ' ': 54 | words = prefix.value.strip().split(' ') 55 | cond_prob = self.get_cond_log_prob(words) 56 | cond_prob *= self.alpha 57 | cond_prob += self.beta 58 | 59 | return self._to_base_e(cond_prob) 60 | 61 | return 0.0 62 | 63 | def final_prefix_score(self, prefix): 64 | if prefix.symbol != ' ': 65 | words = prefix.value.strip().split(' ') 66 | cond_prob = self.get_cond_log_prob(words) 67 | cond_prob *= self.alpha 68 | cond_prob += self.beta 69 | 70 | return self._to_base_e(cond_prob) 71 | 72 | return 0.0 73 | 74 | def is_valid_prefix(self, value): 75 | last_word = value.strip().split(' ')[-1] 76 | return last_word in self.word_prefixes[len(last_word)] 77 | 78 | def get_cond_log_prob(self, sequence): 79 | sequence = sequence[-self.lm.order:] 80 | 81 | in_state = kenlm.State() 82 | self.lm.NullContextWrite(in_state) 83 | out_state = kenlm.State() 84 | 85 | for word in sequence: 86 | if word not in self.words: 87 | return OOV_WORD_PENALTY 88 | 89 | lm_prob = self.lm.BaseScore( 90 | in_state, word, out_state 91 | ) 92 | tmp_state = in_state 93 | in_state = out_state 94 | out_state = tmp_state 95 | 96 | return lm_prob 97 | 98 | def _to_base_e(self, x): 99 | return x / np.log(10) 100 | 101 | def _get_words(self, path): 102 | words = set() 103 | 104 | with open(path, 'r') as f: 105 | start_1_gram = False 106 | end_1_gram = False 107 | 108 | while not end_1_gram: 109 | line = f.readline().strip() 110 | 111 | if line == '\\1-grams:': 112 | print('found 1gram') 113 | start_1_gram = True 114 | 115 | elif line == '\\2-grams:': 116 | print('found 2gram') 117 | end_1_gram = True 118 | 119 | elif start_1_gram and line != '': 120 | parts = line.split('\t') 121 | if len(parts) == 3: 122 | words.add(parts[1]) 123 | 124 | return words 125 | 126 | def _get_word_prefixes(self, words): 127 | word_prefixes = collections.defaultdict(set) 128 | 129 | for word in words: 130 | for i in range(1, len(word) + 1): 131 | word_prefixes[i].add(word[:i]) 132 | 133 | return word_prefixes 134 | 135 | 136 | class CharOfWordKenLMScorer(Scorer): 137 | 138 | def __init__(self, path, alpha, beta): 139 | self.path = path 140 | self.alpha = alpha 141 | self.beta = beta 142 | 143 | self.lm = kenlm.Model(path) 144 | 145 | def score_prefix(self, prefix): 146 | if prefix.symbol != ' ': 147 | words = prefix.value.strip().split(' ') 148 | last_word = words[-1] 149 | 150 | total_cond_prob = 0.0 151 | 152 | # Account for multi-char symbols 153 | for i in range(len(prefix.symbol)): 154 | part = last_word[:len(last_word)-i] 155 | chars = list(part) 156 | 157 | cond_prob = self.get_cond_log_prob(chars) 158 | total_cond_prob += cond_prob * self.alpha 159 | total_cond_prob += self.beta 160 | 161 | return self._to_base_e(total_cond_prob) 162 | 163 | return 0.0 164 | 165 | def final_prefix_score(self, prefix): 166 | return 0.0 167 | 168 | def is_valid_prefix(self, value): 169 | return True 170 | 171 | def get_cond_log_prob(self, sequence): 172 | sequence = sequence[-self.lm.order:] 173 | 174 | in_state = kenlm.State() 175 | self.lm.NullContextWrite(in_state) 176 | out_state = kenlm.State() 177 | 178 | for word in sequence: 179 | lm_prob = self.lm.BaseScore( 180 | in_state, word, out_state 181 | ) 182 | tmp_state = in_state 183 | in_state = out_state 184 | out_state = tmp_state 185 | 186 | return lm_prob 187 | 188 | def _to_base_e(self, x): 189 | return x / np.log(10) 190 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | minversion = 3.0 6 | testpaths = tests 7 | addopts = --benchmark-autosave --benchmark-cprofile=tottime 8 | 9 | [flake8] 10 | max-line-length=80 11 | exclude = .eggs,.git,build/*,bin/*,docs/*,lib/*,examples/*,.venv/,venv/,.35venv/ 12 | inline-quotes = single 13 | 14 | [metadata] 15 | license_file = LICENSE 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages 4 | from setuptools import setup 5 | 6 | ################################################## 7 | # Dependencies 8 | ################################################## 9 | 10 | PYTEST_VERSION_ = '5.2.3' 11 | KENLM_COMMIT = '96d303cfb1a0c21b8f060dbad640d7ab301c019a' 12 | 13 | # Packages required in 'production' 14 | REQUIRED = [ 15 | 'tqdm==4.39.0', 16 | 'kenlm @ git+ssh://git@github.com/kpu/kenlm@{}#egg=kenlm'.format( 17 | KENLM_COMMIT 18 | ), 19 | 'numpy==1.16.2', 20 | 'psutil==5.6.7', 21 | ] 22 | 23 | # Packages required for dev/ci enrionment 24 | EXTRAS = { 25 | 'dev': [ 26 | 'click==7.0', 27 | 'pytest==%s' % (PYTEST_VERSION_,), 28 | 'pytest-runner==5.2', 29 | 'pytest-cov==2.8.1', 30 | 'pytest-benchmark==3.2.2', 31 | ], 32 | 'ci': [ 33 | 'flake8==3.7.9', 34 | 'flake8-quotes==2.1.1' 35 | ], 36 | } 37 | 38 | # Packages required for testing 39 | TESTS = [ 40 | 'pytest==%s' % (PYTEST_VERSION_,), 41 | 'requests_mock==1.7.0' 42 | ] 43 | 44 | ################################################## 45 | # Description 46 | ################################################## 47 | 48 | DESCRIPTION = 'ctcdecode is package to decode output from CTC trained dnn.' 49 | 50 | root = os.path.abspath(os.path.dirname(__file__)) 51 | readme_path = os.path.join(root, 'README.md') 52 | 53 | # Import the README and use it as the long-description. 54 | try: 55 | with open(readme_path, encoding='utf-8') as f: 56 | long_description = '\n' + f.read() 57 | except FileNotFoundError: 58 | long_description = DESCRIPTION 59 | 60 | ################################################## 61 | # SETUP 62 | ################################################## 63 | 64 | setup(name='ctcdecode', 65 | version='0.0.0', 66 | description=DESCRIPTION, 67 | long_description=long_description, 68 | long_description_content_type='text/markdown', 69 | url='https://github.com/ynop/ctcdecode', 70 | download_url='https://github.com/ynop/ctcdecode/releases', 71 | author='Matthias Buechi', 72 | author_email='buec@zhaw.ch', 73 | classifiers=[ 74 | 'Intended Audience :: Science/Research', 75 | 'License :: OSI Approved :: MIT License', 76 | 'Programming Language :: Python :: 3 :: Only', 77 | 'Topic :: Scientific/Engineering :: Human Machine Interfaces' 78 | ], 79 | keywords=('speech recognition lexicon dictionary ' 80 | 'phone phoneme pronunciation'), 81 | license='MIT', 82 | packages=find_packages(exclude=['tests']), 83 | install_requires=REQUIRED, 84 | include_package_data=True, 85 | zip_safe=False, 86 | test_suite='tests', 87 | extras_require=EXTRAS, 88 | setup_requires=['pytest-runner'], 89 | tests_require=TESTS, 90 | entry_points={} 91 | ) 92 | --------------------------------------------------------------------------------