├── README.md
├── examples
│   ├── example_1518.p
│   ├── example_2002.p
│   └── example_99.p
├── language_model.p
├── prefix_beam_search.py
└── test.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Prefix Beam Search
Code for the prefix beam search tutorial by @borgholt (FKA @labodk)

Link: https://medium.com/corti-ai/ctc-networks-and-language-models-prefix-beam-search-explained-c11d1ee23306

### Code
This repository contains two files with Python code:

* `prefix_beam_search.py` contains all the code explained in the tutorial, i.e., the actual prefix beam search algorithm.
* `test.py` loads a language model, performs prefix beam search on three examples, and prints each result alongside the output of a greedy decoder for comparison.

### Examples
The `examples` folder contains three examples of CTC output (2D NumPy arrays) from a CNN-based acoustic model. The model is trained on the LibriSpeech corpus (http://www.openslr.org/12). When executing `test.py` you should get the following output:

```
examples/example_2002.p

BEFORE:
alloud laugh followed at chunkeys expencs

AFTER:
a loud laugh followed at chunkys expense

examples/example_99.p

BEFORE:
but no ghoes tor anything else appeared upon the angient wall

AFTER:
but no ghost or anything else appeared upon the ancient walls

examples/example_1518.p

BEFORE:
mister qualter as the apostle of the middle classes and we re glad twelcomed his gospe

AFTER:
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
```

Note that these examples were handpicked: for each of them, the transcript produced by the prefix beam search matches the true transcript.

### Language Model
The file `language_model.p` contains a dictionary mapping every prefix queried while decoding the three examples to its language model probability. Consequently, this "language model" only works for the three provided examples. The original language model file was too large to upload here, but a range of similar pre-trained models can be found on the LibriSpeech website (http://www.openslr.org/11). The original model used in the tutorial was trained with the KenLM Language Model Toolkit (https://kheafield.com/code/kenlm/) on the additional language modeling data of the LibriSpeech corpus.
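
For reference, `language_model.p` is nothing more than a pickled Python dictionary. A minimal sketch of how such a file could be created (the prefixes and probability values below are made up for illustration and are not the actual file contents):

```
import pickle

# Hypothetical entries: the real file holds one entry for every prefix
# queried while decoding the three examples.
lm_dict = {
    'a': 0.019,
    'a loud': 0.0002,
    'a loud laugh': 0.00001,
}

with open('language_model.p', 'wb') as f:
    pickle.dump(lm_dict, f)
```

`test.py` wraps this dictionary in a small `LanguageModel` class that falls back to a tiny non-zero probability (1e-11) for unseen prefixes.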
### Dependencies

* `numpy`
--------------------------------------------------------------------------------
/examples/example_1518.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corticph/prefix-beam-search/a2dec5d989730bd9e631ad2304851262fd63e39a/examples/example_1518.p
--------------------------------------------------------------------------------
/examples/example_2002.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corticph/prefix-beam-search/a2dec5d989730bd9e631ad2304851262fd63e39a/examples/example_2002.p
--------------------------------------------------------------------------------
/examples/example_99.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corticph/prefix-beam-search/a2dec5d989730bd9e631ad2304851262fd63e39a/examples/example_99.p
--------------------------------------------------------------------------------
/language_model.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/corticph/prefix-beam-search/a2dec5d989730bd9e631ad2304851262fd63e39a/language_model.p
--------------------------------------------------------------------------------
/prefix_beam_search.py:
--------------------------------------------------------------------------------
from collections import defaultdict, Counter
from string import ascii_lowercase
import re

import numpy as np


def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
    """
    Performs prefix beam search on the output of a CTC network.

    Args:
        ctc (np.ndarray): The CTC output. Should be a 2D array (timesteps x alphabet_size).
        lm (func): Language model function. Should take a string as input and output a probability.
        k (int): The beam width. Will keep the 'k' most likely candidates at each timestep.
        alpha (float): The language model weight. Should usually be between 0 and 1.
        beta (float): The language model compensation term. The higher 'alpha' is, the higher 'beta' should be.
        prune (float): Only extend prefixes with characters whose emission probability exceeds 'prune'.

    Returns:
        string: The decoded CTC output.
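
    Example (illustrative; `ctc_matrix` stands in for a real timesteps x alphabet_size
    array and `my_lm` for a callable mapping a string to a probability):
        >>> decoded = prefix_beam_search(ctc_matrix, lm=my_lm, k=25)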
20 | """ 21 | 22 | lm = (lambda l: 1) if lm is None else lm # if no LM is provided, just set to function returning 1 23 | W = lambda l: re.findall(r'\w+[\s|>]', l) 24 | alphabet = list(ascii_lowercase) + [' ', '>', '%'] 25 | F = ctc.shape[1] 26 | ctc = np.vstack((np.zeros(F), ctc)) # just add an imaginative zero'th step (will make indexing more intuitive) 27 | T = ctc.shape[0] 28 | 29 | # STEP 1: Initiliazation 30 | O = '' 31 | Pb, Pnb = defaultdict(Counter), defaultdict(Counter) 32 | Pb[0][O] = 1 33 | Pnb[0][O] = 0 34 | A_prev = [O] 35 | # END: STEP 1 36 | 37 | # STEP 2: Iterations and pruning 38 | for t in range(1, T): 39 | pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]] 40 | for l in A_prev: 41 | 42 | if len(l) > 0 and l[-1] == '>': 43 | Pb[t][l] = Pb[t - 1][l] 44 | Pnb[t][l] = Pnb[t - 1][l] 45 | continue 46 | 47 | for c in pruned_alphabet: 48 | c_ix = alphabet.index(c) 49 | # END: STEP 2 50 | 51 | # STEP 3: “Extending” with a blank 52 | if c == '%': 53 | Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l]) 54 | # END: STEP 3 55 | 56 | # STEP 4: Extending with the end character 57 | else: 58 | l_plus = l + c 59 | if len(l) > 0 and c == l[-1]: 60 | Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l] 61 | Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l] 62 | # END: STEP 4 63 | 64 | # STEP 5: Extending with any other non-blank character and LM constraints 65 | elif len(l.replace(' ', '')) > 0 and c in (' ', '>'): 66 | lm_prob = lm(l_plus.strip(' >')) ** alpha 67 | Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l]) 68 | else: 69 | Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l]) 70 | # END: STEP 5 71 | 72 | # STEP 6: Make use of discarded prefixes 73 | if l_plus not in A_prev: 74 | Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus]) 75 | Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus] 76 | # END: STEP 6 77 | 78 | # STEP 7: Select most probable prefixes 79 | A_next = Pb[t] + Pnb[t] 80 | sorter = lambda l: A_next[l] * (len(W(l)) + 1) ** beta 81 | A_prev = sorted(A_next, key=sorter, reverse=True)[:k] 82 | # END: STEP 7 83 | 84 | return A_prev[0].strip('>') -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from string import ascii_lowercase 3 | from collections import defaultdict 4 | import pickle 5 | import numpy as np 6 | from prefix_beam_search import prefix_beam_search 7 | 8 | class LanguageModel(object): 9 | """ 10 | Loads a dictionary mapping between prefixes and probabilities. 11 | """ 12 | 13 | def __init__(self, lm_file): 14 | """ 15 | Initializes the language model. 16 | 17 | Args: 18 | lm_file (str): Path to dictionary mapping between prefixes and lm probabilities. 19 | """ 20 | lm = pickle.load(open(lm_file, 'rb')) 21 | self._model = defaultdict(lambda: 1e-11, lm) 22 | 23 | def __call__(self, prefix): 24 | """ 25 | Returns the probability of the last word conditioned on all previous ones. 26 | 27 | Args: 28 | prefix (str): The sentence prefix to be scored. 29 | """ 30 | return self._model[prefix] 31 | 32 | def greedy_decoder(ctc): 33 | """ 34 | Performs greedy decoding (max decoding) on the output of a CTC network. 35 | 36 | Args: 37 | ctc (np.ndarray): The CTC output. Should be a 2D array (timesteps x alphabet_size) 38 | 39 | Returns: 40 | string: The decoded CTC output. 
41 | """ 42 | 43 | alphabet = list(ascii_lowercase) + [' ', '>'] 44 | alphabet_size = len(alphabet) 45 | 46 | # collapse repeating characters 47 | arg_max = np.argmax(ctc, axis=1) 48 | repeat_filter = arg_max[1:] != arg_max[:-1] 49 | repeat_filter = np.concatenate([[True], repeat_filter]) 50 | collapsed = arg_max[repeat_filter] 51 | 52 | # discard blank tokens (the blank is always last in the alphabet) 53 | blank_filter = np.where(collapsed < (alphabet_size - 1))[0] 54 | final_sequence = collapsed[blank_filter] 55 | full_decode = ''.join([alphabet[letter_idx] for letter_idx in final_sequence]) 56 | 57 | return full_decode[:full_decode.find('>')] 58 | 59 | if __name__ == '__main__': 60 | lm = LanguageModel('language_model.p') 61 | for example_file in glob('examples/*.p'): 62 | example = pickle.load(open(example_file, 'rb')) 63 | before_lm = greedy_decoder(example) 64 | after_lm = prefix_beam_search(example, lm=lm) 65 | print('\n{}'.format(example_file)) 66 | print('\nBEFORE:\n{}'.format(before_lm)) 67 | print('\nAFTER:\n{}'.format(after_lm)) --------------------------------------------------------------------------------