├── README.md
├── ctc.py
└── example.py

/README.md:
--------------------------------------------------------------------------------
A primer on CTC implementation in pure Python PyTorch code. This implementation is not suitable for real-world use, only for experimentation and research on CTC modifications. Features:
- CTC implementation is in Python and its only loop is over time steps (it parallelizes over the batch and symbol dimensions)
- Gradients are computed via PyTorch autograd instead of a separate beta computation
- Viterbi path computation, useful for forced alignment
- Get alignment targets out of any CTC implementation, so that label smoothing or reweighting can be applied [1, 2] (see the sketch after this list)
- It might support double-backwards (not checked)
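For example, the alignment targets can be plugged into a reweighted frame-wise cross-entropy. A minimal sketch, assuming `log_probs` is the `(time x batch x channels)` output of `log_softmax` (requiring grad) and `targets`, `input_lengths`, `target_lengths` are the usual CTC tensors; the per-class weighting below is only an illustration, not the exact scheme of [1] or [2]:
```
import torch
import ctc

# soft alignment targets with the same shape as log_probs, already detached
alignment_targets = ctc.ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0)

# illustrative per-class weights, e.g. downweighting the blank class
class_weights = torch.ones(log_probs.shape[-1], device = log_probs.device)
class_weights[0] = 0.5

# reweighted cross-entropy against the CTC alignment targets
loss = -(class_weights * alignment_targets * log_probs).sum() / log_probs.shape[1]
loss.backward()
```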
### Very rough time measurements
```
Device: cuda
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.002052783966064453 bwd 0.0167086124420166
Custom CTC loss fwd 0.09685754776000977 bwd 0.14192843437194824
Custom loss matches: True
Grad matches: True
CE grad matches: True

Device: cpu
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.017746925354003906 bwd 0.21297860145568848
Custom CTC loss fwd 0.38710451126098633 bwd 5.190514087677002
Custom loss matches: True
Grad matches: True
CE grad matches: True
```

### Very rough time measurements if the custom logsumexp is used
```
Device: cuda
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.009581804275512695 bwd 0.012355327606201172
Custom CTC loss fwd 0.09775996208190918 bwd 0.1494584083557129
Custom loss matches: True
Grad matches: True
CE grad matches: True

Device: cpu
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.017041444778442383 bwd 0.23205327987670898
Custom CTC loss fwd 0.3748452663421631 bwd 4.206061363220215
Custom loss matches: True
Grad matches: True
CE grad matches: True
```

### Alignment image example
![](https://user-images.githubusercontent.com/1041752/71736894-8615e800-2e52-11ea-81cb-cb95b92175c6.png)

### References (CTC)
1. A Novel Re-weighting Method for Connectionist Temporal Classification; Li et al.; https://arxiv.org/abs/1904.10619
2. Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets; Feng et al.; https://www.hindawi.com/journals/complexity/2019/9345861/
3. Improved training for online end-to-end speech recognition systems; Kim et al.; https://arxiv.org/abs/1711.02212
4. Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks; Graves et al.; https://www.cs.toronto.edu/~graves/icml_2006.pdf
5. Sequence Modeling With CTC; Hannun et al.; https://distill.pub/2017/ctc/
6. My two related gists:
   - Loop-based CTC forward: https://gist.github.com/vadimkantorov/c1aa417cffa1450b03716c740795f107
   - CTC targets: https://gist.github.com/vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4
7. Other CTC implementations:
   - https://github.com/rakeshvar/rnn_ctc/blob/master/nnet/ctc.py#L96
   - https://github.com/artbataev/end2end/blob/master/pytorch_end2end/src/losses/forward_backward.cpp
   - https://github.com/jamesdanged/LatticeCtc
   - https://github.com/zh217/torch-asg/blob/master/torch_asg/native/force_aligned_lattice.cpp
   - https://github.com/amaas/stanford-ctc/blob/master/ctc/ctc.py
   - https://github.com/skaae/Lasagne-CTC/blob/master/ctc_cost.py
   - https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LossCTC.cpp#L37
   - https://github.com/musyoku/chainer-gram-ctc
   - https://github.com/musyoku/chainer-cuda-ctc
   - https://github.com/1ytic/warp-rnnt

### References (beam search)
- https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7
- https://medium.com/corti-ai/ctc-networks-and-language-models-prefix-beam-search-explained-c11d1ee23306
- https://github.com/githubharald/CTCDecoder
- https://github.com/githubharald/CTCWordBeamSearch
- https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
- https://github.com/wouterkool/stochastic-beam-search
- https://github.com/mjansche/ctc_sampling
- https://www.aclweb.org/anthology/D19-1331/
- https://arxiv.org/abs/1905.08760
- https://arxiv.org/abs/1804.07915
- http://proceedings.mlr.press/v97/cohen19a/cohen19a.pdf
- https://github.com/corticph/prefix-beam-search/
--------------------------------------------------------------------------------
/ctc.py:
--------------------------------------------------------------------------------
# TODO: try to replace fancy tensor indexing by gather / scatter

import math
import torch

#@torch.jit.script
def ctc_loss(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank : int = 0, reduction : str = 'none', finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min, alignment : bool = False) -> torch.Tensor:
    input_time_size, batch_size = log_probs.shape[:2]
    B = torch.arange(batch_size, device = input_lengths.device)

    # blank-interleaved targets: [blank, t1, blank, t2, ..., blank, tN, blank, t1], the trailing symbol is padding
    _t_a_r_g_e_t_s_ = torch.cat([targets, targets[:, :1]], dim = -1)
    _t_a_r_g_e_t_s_ = torch.stack([torch.full_like(_t_a_r_g_e_t_s_, blank), _t_a_r_g_e_t_s_], dim = -1).flatten(start_dim = -2)

    # diff_labels[:, s] is True where the skip transition from position s - 2 is allowed
    diff_labels = torch.cat([torch.tensor([[False, False]], device = targets.device).expand(batch_size, -1), _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]], dim = 1)

    # if zero = float('-inf') is used as the log-space neutral element, a custom logsumexp must be used to avoid nan grad in torch.logsumexp

    zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
    log_probs_ = log_probs.gather(-1, _t_a_r_g_e_t_s_.expand(input_time_size, -1, -1))
    log_alpha = torch.full((input_time_size, batch_size, zero_padding + _t_a_r_g_e_t_s_.shape[-1]), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[0, :, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[0, :, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]
    # log_alpha[1:, :, zero_padding:] = log_probs.gather(-1, _t_a_r_g_e_t_s_.expand(len(log_probs), -1, -1))[1:]
    for t in range(1, input_time_size):
        log_alpha[t, :, 2:] = log_probs_[t] + logadd(log_alpha[t - 1, :, 2:], log_alpha[t - 1, :, 1:-1], torch.where(diff_labels, log_alpha[t - 1, :, :-2], zero))

    l1l2 = log_alpha[input_lengths - 1, B].gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))
    loss = -torch.logsumexp(l1l2, dim = -1)

    if not alignment:
        return loss

    # below is for debugging; for real alignment use the more efficient, distinct ctc_alignment(...) method
    path = torch.zeros(len(log_alpha), len(B), device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + 2 * target_lengths - 1 + l1l2.argmax(dim = -1)
    for t, indices in reversed(list(enumerate(path))[1:]):
        indices_ = torch.stack([(indices - 2) * diff_labels[B, (indices - zero_padding).clamp(min = 0)], (indices - 1).clamp(min = 0), indices], dim = -1)
        path[t - 1] += (indices - 2 + log_alpha[t - 1, B].gather(-1, indices_).argmax(dim = -1)).clamp(min = 0)
    return torch.zeros_like(log_alpha).scatter_(-1, path.unsqueeze(-1), 1.0)[..., (zero_padding + 1)::2]
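# Illustration of the layout used by ctc_loss above: for a single target sequence [a, b, b] with blank = 0,
# _t_a_r_g_e_t_s_ is [0, a, 0, b, 0, b, 0, a] -- the standard CTC extension [blank, a, blank, b, blank, b, blank]
# plus one trailing padding symbol, so its length is always 2 * len(targets) + 2 -- and diff_labels is
# [F, F, F, T, F, F, F, T]: the skip transition from s - 2 to s is allowed only where the labels two apart
# differ, i.e. never into a blank and never between repeated labels. The loop above is the usual CTC alpha
# recursion, alpha[t, s] = log_probs[t, label(s)] + logsumexp over {alpha[t-1, s], alpha[t-1, s-1],
# alpha[t-1, s-2] if allowed}, and the loss logsumexps the alphas of the final label and the final blank
# at the last valid frame. ctc_alignment below uses the same recursion but appends a final blank instead
# of the trailing padding symbol, so its extended targets have length 2 * len(targets) + 1.
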
#@torch.jit.script
def ctc_alignment(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank: int = 0, finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min) -> torch.Tensor:
    input_time_size, batch_size = log_probs.shape[:2]
    B = torch.arange(batch_size, device = input_lengths.device)

    _t_a_r_g_e_t_s_ = torch.cat([
        torch.stack([torch.full_like(targets, blank), targets], dim = -1).flatten(start_dim = -2),
        torch.full_like(targets[:, :1], blank)
    ], dim = -1)
    diff_labels = torch.cat([
        torch.tensor([[False, False]], device = targets.device).expand(batch_size, -1),
        _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]
    ], dim = 1)

    zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
    padded_t = zero_padding + _t_a_r_g_e_t_s_.shape[-1]
    log_alpha = torch.full((batch_size, padded_t), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[:, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[:, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]

    packmask = 0b11
    packnibbles = 4 # packnibbles = 1
    backpointers_shape = [len(log_probs), batch_size, int(math.ceil(padded_t / packnibbles))]
    backpointers = torch.zeros(backpointers_shape, device = log_probs.device, dtype = torch.uint8)
    backpointer = torch.zeros((backpointers_shape[-2], backpointers_shape[-1] * packnibbles), device = log_probs.device, dtype = torch.uint8)
    packshift = torch.tensor([[[6, 4, 2, 0]]], device = log_probs.device, dtype = torch.uint8)

    for t in range(1, input_time_size):
        prev = torch.stack([log_alpha[:, 2:], log_alpha[:, 1:-1], torch.where(diff_labels, log_alpha[:, :-2], zero)])
        log_alpha[:, zero_padding:] = log_probs[t].gather(-1, _t_a_r_g_e_t_s_) + prev.logsumexp(dim = 0)
        backpointer[:, zero_padding:(zero_padding + prev.shape[-1])] = prev.argmax(dim = 0)
        torch.sum(backpointer.unflatten(-1, (-1, packnibbles)) << packshift, dim = -1, out = backpointers[t]) # backpointers[t] = backpointer

    l1l2 = log_alpha.gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))

    path = torch.zeros(input_time_size, batch_size, device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + target_lengths * 2 - 1 + l1l2.argmax(dim = -1)

    for t in range(input_time_size - 1, 0, -1):
        indices = path[t]
        backpointer = (backpointers[t].unsqueeze(-1) >> packshift).view_as(backpointer) # backpointer = backpointers[t]
        path[t - 1] += indices - backpointer.gather(-1, indices.unsqueeze(-1)).squeeze(-1).bitwise_and_(packmask)

    return torch.zeros_like(_t_a_r_g_e_t_s_, dtype = torch.int64).scatter_(-1, (path.t() - zero_padding).clamp(min = 0), torch.arange(input_time_size, device = log_alpha.device).expand(batch_size, -1))[:, 1::2]
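# Backpointer packing, illustrated: each backpointer is 0 (stay at position s), 1 (came from s - 1) or
# 2 (came from s - 2), so it fits into 2 bits and 4 of them are packed into one uint8 using the shifts
# [6, 4, 2, 0]. For example [2, 1, 0, 2] packs into (2 << 6) + (1 << 4) + (0 << 2) + (2 << 0) = 146,
# and unpacking recovers each value as (146 >> shift) & 0b11. This cuts backpointer memory by 4x at the
# cost of the shift/mask work done in ctc_alignment above (ctc_alignment_uncompressed below makes the
# packing optional).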


def ctc_alignment_uncompressed(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank: int = 0, pack_backpointers: bool = False, finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min) -> torch.Tensor:
    B = torch.arange(len(targets), device = input_lengths.device)
    _t_a_r_g_e_t_s_ = torch.cat([
        torch.stack([torch.full_like(targets, blank), targets], dim = -1).flatten(start_dim = -2),
        torch.full_like(targets[:, :1], blank)
    ], dim = -1)
    diff_labels = torch.cat([
        torch.as_tensor([[False, False]], device = targets.device).expand(len(B), -1),
        _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]
    ], dim = 1)

    zero, zero_padding = torch.tensor(finfo_min_fp16 if log_probs.dtype is torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype), 2
    padded_t = zero_padding + _t_a_r_g_e_t_s_.shape[-1]
    log_alpha = torch.full((len(B), padded_t), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[:, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[:, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]

    packmask = 0b11
    packnibbles = 4
    padded_t = int(math.ceil(padded_t / packnibbles)) * packnibbles
    backpointers_shape = [len(log_probs), len(B), padded_t]
    backpointers = torch.zeros(
        backpointers_shape if not pack_backpointers else (backpointers_shape[:-1] + [padded_t // packnibbles]),
        device = log_probs.device,
        dtype = torch.uint8
    )
    backpointer = torch.zeros(backpointers_shape[1:], device = log_probs.device, dtype = torch.uint8)
    packshift = torch.tensor([[[6, 4, 2, 0]]], device = log_probs.device, dtype = torch.uint8)

    for t in range(1, len(log_probs)):
        prev = torch.stack([log_alpha[:, 2:], log_alpha[:, 1:-1], torch.where(diff_labels, log_alpha[:, :-2], zero)])
        log_alpha[:, 2:] = log_probs[t].gather(-1, _t_a_r_g_e_t_s_) + prev.logsumexp(dim = 0)
        backpointer[:, 2:(2 + prev.shape[-1])] = prev.argmax(dim = 0)
        if pack_backpointers:
            torch.sum(backpointer.view(len(backpointer), -1, 4) << packshift, dim = -1, out = backpointers[t])
        else:
            backpointers[t] = backpointer

    l1l2 = log_alpha.gather(
        -1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1)
    )

    path = torch.zeros(len(log_probs), len(B), device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + target_lengths * 2 - 1 + l1l2.argmax(dim = -1)

    for t in range(len(path) - 1, 0, -1):
        indices = path[t]

        if pack_backpointers:
            backpointer = (backpointers[t].unsqueeze(-1) >> packshift).view_as(backpointer)
        else:
            backpointer = backpointers[t]

        path[t - 1] += indices - backpointer.gather(-1, indices.unsqueeze(-1)).squeeze(-1).bitwise_and_(packmask)
    return torch.zeros_like(_t_a_r_g_e_t_s_, dtype = torch.int64).scatter_(
        -1, (path.t() - zero_padding).clamp(min = 0),
        torch.arange(len(path), device = log_alpha.device).expand(len(B), -1)
    )[:, 1::2]

def ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0, ctc_loss = torch.nn.functional.ctc_loss, retain_graph = True):
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = blank, reduction = 'sum')
    probs = log_probs.exp()
    # to simplify the API, the log_softmax gradient is inlined, i.e. the next two lines are equivalent to: grad_logits, = torch.autograd.grad(loss, logits, retain_graph = True); the gradient formula is explained at https://stackoverflow.com/questions/35304393/trying-to-understand-code-that-computes-the-gradient-wrt-to-the-input-for-logsof
    grad_log_probs, = torch.autograd.grad(loss, log_probs, retain_graph = retain_graph)
    grad_logits = grad_log_probs - probs * grad_log_probs.sum(dim = -1, keepdim = True)
    temporal_mask = (torch.arange(len(log_probs), device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(1) < input_lengths.unsqueeze(0)).unsqueeze(-1)
    return (probs * temporal_mask - grad_logits).detach()
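# Note on the identity used above: for CTC on top of a log_softmax, the gradient of the loss w.r.t. the
# logits has the form probs - soft_targets, where soft_targets are the per-frame posteriors of the
# extended target symbols summed over all valid alignments (and zero beyond input_lengths).
# ctc_alignment_targets therefore recovers them as probs * temporal_mask - grad_logits, and training
# with cross-entropy against these detached targets reproduces the CTC gradient exactly (example.py
# checks this), which is what makes re-weighting or label-smoothing modifications easy to bolt on.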
def logadd(x0, x1, x2):
    # produces nan gradients in backward if the -inf log-space zero element is used, see https://github.com/pytorch/pytorch/issues/31829
    return torch.logsumexp(torch.stack([x0, x1, x2]), dim = 0)

    # use if the -inf log-space zero element is used
    #return LogsumexpFunction.apply(x0, x1, x2)

    # produces an inplace modification error, see https://github.com/pytorch/pytorch/issues/31819
    #m = torch.max(torch.max(x0, x1), x2)
    #m = m.masked_fill(torch.isinf(m), 0)
    #res = (x0 - m).exp() + (x1 - m).exp() + (x2 - m).exp()
    #return res.log().add(m)

class LogsumexpFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(self, x0, x1, x2):
        m = torch.max(torch.max(x0, x1), x2)
        m = m.masked_fill_(torch.isinf(m), 0)
        e0 = (x0 - m).exp_()
        e1 = (x1 - m).exp_()
        e2 = (x2 - m).exp_()
        e = (e0 + e1).add_(e2).clamp_(min = 1e-16)
        self.save_for_backward(e0, e1, e2, e)
        return e.log_().add_(m)

    @staticmethod
    def backward(self, grad_output):
        e0, e1, e2, e = self.saved_tensors
        g = grad_output / e
        return (g * e0, g * e1, g * e2)
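# Minimal sketch of why LogsumexpFunction exists: with float('-inf') as the log-space zero,
# torch.logsumexp backpropagates nan wherever all stacked inputs are -inf, while LogsumexpFunction
# clamps the sum of exponentials and returns zero gradients there instead, e.g.:
#
#   x = torch.full((3,), float('-inf'), requires_grad = True)
#   torch.logsumexp(torch.stack([x, x, x]), dim = 0).sum().backward()  # x.grad is all nan
#
#   y = torch.full((3,), float('-inf'), requires_grad = True)
#   LogsumexpFunction.apply(y, y, y).sum().backward()                  # y.grad is all zeros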
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import time
import matplotlib.pyplot as plt

import torch

import ctc

T, B, C = 128, 256, 32
t = T // 2 - 4
blank = 0
device = 'cpu' #'cuda'
seed = 1
atol = 1e-3
for set_seed in [torch.manual_seed] + ([torch.cuda.manual_seed_all] if device == 'cuda' else []):
    set_seed(seed)
tictoc = lambda: (device == 'cuda' and torch.cuda.synchronize()) or time.time()

logits = torch.randn(T, B, C, device = device).requires_grad_()
targets = torch.randint(blank + 1, C, (B, t), dtype = torch.long, device = device)
input_lengths = torch.full((B,), T, dtype = torch.long, device = device)
target_lengths = torch.full((B,), t, dtype = torch.long, device = device)
log_probs = logits.log_softmax(dim = -1)

print('Device:', device)
print('Log-probs shape (time X batch X channels):', 'x'.join(map(str, log_probs.shape)))

tic = tictoc()
builtin_ctc = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
toc = tictoc()
builtin_ctc_grad, = torch.autograd.grad(builtin_ctc.sum(), logits, retain_graph = True)
print('Built-in CTC loss', 'fwd', toc - tic, 'bwd', tictoc() - toc)

tic = tictoc()
custom_ctc = ctc.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
toc = tictoc()
custom_ctc_grad, = torch.autograd.grad(custom_ctc.sum(), logits, retain_graph = True)
print('Custom CTC loss', 'fwd', toc - tic, 'bwd', tictoc() - toc)

ce_alignment_targets = ctc.ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0)
ce_ctc = -ce_alignment_targets * log_probs
ce_ctc_grad, = torch.autograd.grad(ce_ctc.sum(), logits, retain_graph = True)

print('Custom loss matches:', torch.allclose(builtin_ctc, custom_ctc, atol = atol))
print('Grad matches:', torch.allclose(builtin_ctc_grad, custom_ctc_grad, atol = atol))
print('CE grad matches:', torch.allclose(builtin_ctc_grad, ce_ctc_grad, atol = atol))

alignment = ctc.ctc_alignment(log_probs, targets, input_lengths, target_lengths, blank = 0)
a = torch.zeros(T, t)
a[alignment[0, :target_lengths[0]].cpu(), torch.arange(t)] = 1.0 # indices moved to cpu so that the cuda path also works
plt.subplot(211)
plt.title('Input-Output Viterbi alignment')
plt.imshow(a.t().cpu(), origin = 'lower', aspect = 'auto')
plt.xlabel('Input steps')
plt.ylabel('Output steps')
plt.subplot(212)
plt.title('CTC alignment targets')
a = ce_alignment_targets[:, 0, :]
plt.imshow(a.t().cpu(), origin = 'lower', aspect = 'auto')
plt.xlabel('Input steps')
plt.ylabel(f'Output symbols, blank {blank}')
plt.subplots_adjust(hspace = 0.5)
plt.savefig('alignment.png')
--------------------------------------------------------------------------------