├── README.md
└── mwer_loss.py

/README.md:
--------------------------------------------------------------------------------
# mwer
mWER (minimum word error rate) loss implementation in TensorFlow
--------------------------------------------------------------------------------
/mwer_loss.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def mwer_loss(
    candidate_seqs,
    candidate_seq_lens,
    candidate_seq_cnts,
    candidate_seq_edit_errors,
    candidate_seq_logprobs
):
    """Computes the mWER (minimum word error rate) loss.

    This op implements the mWER loss as presented in the paper:

    [Rohit Prabhavalkar et al.
    Minimum Word Error Rate Training for Attention-Based
    Sequence-to-Sequence Models](https://arxiv.org/pdf/1712.01818.pdf)

    Notation:
      B: batch size.
      N: the number of candidate (hypothesis) sequences plus 1; the last
        sequence is treated as the ground truth and is used to compute the CE loss.
      U: max length of the candidate sequences, including SOS (but not EOS).
      V: vocabulary size (number of tokens).

    Args:
      candidate_seqs: An `int32` `Tensor` of shape (B, N, U + 1).
        `candidate_seqs[b, n, u]` is the u-th token id of the n-th candidate
        sequence (including SOS and EOS) of the b-th sample.
      candidate_seq_lens: An `int32` `Tensor` of shape (B, N).
        Actual length of each candidate sequence, including SOS and EOS.
        candidate_seq_lens[b, n] <= U + 1 for all b, n.
      candidate_seq_cnts: An `int32` `Tensor` of shape (B, 1).
        The number of effective candidate sequences per sample;
        candidate_seq_cnts[b] <= N - 1 for all b in range(B). Beam search may
        return fewer than N - 1 hypotheses when it is very confident about its
        top few candidates, so the remaining slots are padding.
      candidate_seq_edit_errors: A `float32` `Tensor` of shape (B, N).
        The edit-distance error of each candidate sequence.
      candidate_seq_logprobs: A `float32` `Tensor` of shape (B, N, U, V).
        `candidate_seq_logprobs[b, n, u, v]` is the log probability of token v
        at the u-th position (not including SOS) of the n-th candidate sequence
        of the b-th sample.

    Returns:
      weighted_relative_edit_errors:
        A 1-D `float` `Tensor` of size `[B]`, the per-sample mWER loss.
      ce_loss:
        A 1-D `float` `Tensor` of size `[B]`, the per-sample CE loss.
      rescore_wer:
        A 1-D `float` `Tensor` of size `[B]`, the WER of the top-1 rescored
        sequences. This output is only for metrics/evaluation and is not
        back-propagated through.
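
    Example:
      A minimal sketch of how the op might be called; the tensors below are
      illustrative placeholders (uniform log-probs, zero edit errors), not
      values from a real model:

        B, N, U, V = 2, 4, 6, 30
        candidate_seqs = tf.zeros((B, N, U + 1), dtype=tf.int32)
        candidate_seq_lens = tf.fill((B, N), U + 1)
        candidate_seq_cnts = tf.fill((B, 1), N - 1)
        candidate_seq_edit_errors = tf.zeros((B, N), dtype=tf.float32)
        candidate_seq_logprobs = tf.math.log(tf.fill((B, N, U, V), 1.0 / V))
        mwer, ce, wer = mwer_loss(
            candidate_seqs, candidate_seq_lens, candidate_seq_cnts,
            candidate_seq_edit_errors, candidate_seq_logprobs)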
    """
    int_shape = candidate_seq_logprobs.get_shape().as_list()
    shape = tf.shape(candidate_seq_logprobs)

    B = int_shape[0] or shape[0]
    N = int_shape[1]
    U = int_shape[2] or shape[2]
    V = int_shape[3]

    # Gather the log-prob of each actual token in every candidate sequence.
    flatten_logprobs = tf.reshape(candidate_seq_logprobs, (-1, V))  # (B * N * U, V)
    flatten_tokens = tf.reshape(candidate_seqs[:, :, 1:], shape=(-1,))  # (B * N * U,)
    indices = tf.transpose(tf.stack([tf.range(B * N * U), flatten_tokens]))  # (B * N * U, 2)
    flatten_logprobs = tf.gather_nd(flatten_logprobs, indices)  # (B * N * U,)
    token_logprobs = tf.reshape(flatten_logprobs, (B * N, U))  # (B * N, U)

    # Mask out padded token positions and sum into per-sequence log-probs.
    token_mask = tf.sequence_mask(
        tf.reshape(candidate_seq_lens - 1, shape=(-1,)),
        maxlen=U,
        dtype=tf.dtypes.float32
    )  # (B * N, U)
    masked_token_logprobs = token_logprobs * token_mask  # (B * N, U)
    masked_token_logprobs = tf.reshape(masked_token_logprobs, (B, N, U))  # (B, N, U)
    seq_logprobs = tf.reduce_sum(masked_token_logprobs, axis=-1)  # (B, N)

    def softmax_with_mask(logits, mask):
        # Push masked-out logits to a large negative value before the softmax.
        mask = tf.cast(mask, tf.dtypes.float32)
        logits -= 10000.0 * (1.0 - mask)
        ai = tf.exp(logits - tf.reduce_max(logits, axis=-1, keepdims=True))
        softmax_result = ai / (tf.reduce_sum(ai, axis=-1, keepdims=True) + 1e-10)
        return softmax_result

    # Mask out the padding sequences and the final ground-truth sequence,
    # then renormalize the sequence probabilities over the effective candidates.
    seq_mask = tf.sequence_mask(
        tf.reshape(candidate_seq_cnts, shape=(-1,)),
        maxlen=N,
        dtype=tf.dtypes.float32
    )  # (B, N)
    renormalized_seq_probs = softmax_with_mask(seq_logprobs, seq_mask)  # (B, N)

    # mWER loss: expected edit error relative to the average over the candidates,
    # weighted by the renormalized hypothesis probabilities.
    masked_edit_errors = seq_mask * candidate_seq_edit_errors  # (B, N)
    avg_edit_errors = tf.reduce_sum(masked_edit_errors, axis=-1, keepdims=True) / tf.cast(candidate_seq_cnts, tf.dtypes.float32)  # (B, 1)
    relative_edit_errors = seq_mask * (masked_edit_errors - tf.tile(avg_edit_errors, (1, N)))
    weighted_relative_edit_errors = tf.reduce_sum(renormalized_seq_probs * relative_edit_errors, axis=-1)  # (B,)

    # The last sequence of each sample is the ground truth and is used for the CE loss.
    ce_loss = -seq_logprobs[:, -1]

    # WER of the highest-probability candidate after rescoring, reported as a metric only.
    top1_seq_indices = tf.argmax(renormalized_seq_probs, axis=-1, output_type=tf.dtypes.int32)  # (B,)
    indices = tf.transpose(tf.stack([tf.range(B), top1_seq_indices]))  # (B, 2)
    chosen_seq_edit_errors = tf.gather_nd(masked_edit_errors, indices)  # (B,)
    ground_seq_len = candidate_seq_lens[:, -1] - 2  # (B,) ground-truth length without SOS/EOS
    rescore_wer = chosen_seq_edit_errors / tf.cast(ground_seq_len, tf.dtypes.float32)
    rescore_wer = tf.stop_gradient(rescore_wer)

    return [weighted_relative_edit_errors, ce_loss, rescore_wer]
--------------------------------------------------------------------------------