├── .gitignore ├── .DS_Store ├── Data └── .gitignore ├── Models └── .gitignore ├── Results └── .gitignore ├── Utils ├── preprocess_transcripts.py ├── remove_duplicates.py ├── filter_dev.py └── data_analysis.py ├── Code ├── plot_attentions.py ├── scoreboard.py ├── dataset.py ├── beamsearch.py ├── model.py └── main.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | Code/__pycache__/* 2 | Utils/__pycache__/* 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prasadlokhande-880/automatic_speech_recognition/HEAD/.DS_Store -------------------------------------------------------------------------------- /Data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /Models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /Results/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /Utils/preprocess_transcripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | # Preprocess Data. Modify to make sentences. 
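# The raw WSJ transcripts are stored as arrays of words whose characters are
# kept as integer byte codes; the loop below decodes each code back to a
# character and joins the words with single spaces, so an utterance stored as,
# e.g., [b'THE', b'CAT'] becomes the string "THE CAT".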
6 | 7 | DUMP_DATA_PATH = './../Data_Orig/Preprocessed' 8 | 9 | SPEECH_DATA_PATH = './../Data_Orig' 10 | # mode = 'train' 11 | 12 | # SPEECH_DATA_PATH = './Data_Clean/Filtered_Dev' 13 | mode = 'dev' 14 | 15 | data = np.load(os.path.join(SPEECH_DATA_PATH, 16 | '{}_transcripts.npy'.format(mode)), 17 | encoding='bytes') 18 | for i, utt in enumerate(data): 19 | s = "" 20 | for w in utt: 21 | for c in w: 22 | s += chr(c) 23 | s += " " 24 | s = s[:-1] 25 | data[i] = s 26 | 27 | np.save(os.path.join(DUMP_DATA_PATH, '{}_transcripts.npy'.format(mode)), data) 28 | -------------------------------------------------------------------------------- /Utils/remove_duplicates.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | SPEECH_DATA_PATH = './../Data' 6 | DUMP_DATA_PATH = './../Data_Clean' 7 | 8 | mode = 'train' 9 | # mode = 'dev' 10 | 11 | data_x = np.load(os.path.join(SPEECH_DATA_PATH, '{}.npy'.format(mode)), 12 | encoding='bytes') 13 | data_y = np.load(os.path.join(SPEECH_DATA_PATH, 14 | '{}_transcripts.npy'.format(mode)), 15 | encoding='bytes') 16 | dup_idx = np.load('{}_duplicates.npy'.format(mode), encoding='bytes') 17 | 18 | assert (len(data_x) == len(data_y)) 19 | data_x_rev = np.delete(data_x, dup_idx) 20 | data_y_rev = np.delete(data_y, dup_idx) 21 | assert (len(data_x_rev) == len(data_y_rev)) 22 | 23 | np.save(os.path.join(DUMP_DATA_PATH, '{}.npy'.format(mode)), data_x_rev) 24 | np.save(os.path.join(DUMP_DATA_PATH, '{}_transcripts.npy'.format(mode)), 25 | data_y_rev) 26 | -------------------------------------------------------------------------------- /Code/plot_attentions.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | # mode = 'train' 5 | # mode = 'val' 6 | mode = 'test' 7 | 8 | if mode == 'train': 9 | attention_weights = torch.load('./attention_weights_train.pt', 10 | map_location='cpu') 11 | elif mode == 'val': 12 | attention_weights = torch.load('./attention_weights_val.pt', 13 | map_location='cpu') 14 | else: 15 | attention_weights = torch.load('./attention_weights_test.pt', 16 | map_location='cpu') 17 | batch_idx = 0 18 | batch_size = len(attention_weights[0]) 19 | fig = plt.figure() 20 | plt.tight_layout() 21 | att_w = torch.from_numpy(attention_weights[0][batch_idx]).unsqueeze(0) 22 | for at in attention_weights[1:]: 23 | att_w = torch.cat((att_w, torch.from_numpy(at[batch_idx]).unsqueeze(0))) 24 | plt.imshow(att_w.detach().numpy()) 25 | plt.show() 26 | -------------------------------------------------------------------------------- /Utils/filter_dev.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | SPEECH_DATA_PATH = './../Data_Clean' 6 | DUMP_DATA_PATH = './../Data_Clean/Filtered_Dev' 7 | 8 | train_y = np.load(os.path.join(SPEECH_DATA_PATH, 'train_transcripts.npy'), 9 | encoding='bytes') 10 | dev_y = np.load(os.path.join(SPEECH_DATA_PATH, 'dev_transcripts.npy'), 11 | encoding='bytes') 12 | dev_x = np.load(os.path.join(SPEECH_DATA_PATH, 'dev.npy'), encoding='bytes') 13 | dup_list = [] 14 | for i in range(len(dev_y)): 15 | for j in range(len(train_y)): 16 | if np.array_equal(dev_y[i], train_y[j]): 17 | dup_list.append(i) 18 | break 19 | 20 | dev_y_rev = np.delete(dev_y, dup_list) 21 | dev_x_rev = np.delete(dev_x, dup_list) 22 | assert (len(dev_y_rev) == len(dev_x_rev)) 23 | 24 | np.save(os.path.join(DUMP_DATA_PATH, 
'dev.npy'), dev_x_rev) 25 | np.save(os.path.join(DUMP_DATA_PATH, 'dev_transcripts.npy'), dev_y_rev) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Anjandeep Singh Sahni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Code/scoreboard.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Scoreboard(): 4 | def __init__(self, sort_param_idx=0, name='NoName'): 5 | self.sb = [] 6 | self.spi = sort_param_idx 7 | self.name = name 8 | 9 | def addItem(self, param_list): 10 | if self.sb: 11 | assert (len(param_list) == len(self.sb[0])) 12 | self.sb.append(param_list) 13 | 14 | def print_scoreboard(self, k, key): 15 | assert (len(key) == len(self.sb[0])) 16 | self.sb.sort(key=lambda x: x[self.spi]) 17 | 18 | print('=' * 20) 19 | print('Printing Scoreboard for', self.name) 20 | print('=' * 20) 21 | print("Top-{}".format(k)) 22 | print('=' * 20) 23 | for i in range(k): 24 | for idx in range(len(key)): 25 | print('{}: {}'.format(key[idx], self.sb[i][idx])) 26 | print('\n') 27 | print('=' * 20) 28 | print("Last-{}".format(k)) 29 | print('=' * 20) 30 | for i in range(-k, 0, 1): 31 | for idx in range(len(key)): 32 | print('{}: {}'.format(key[idx], self.sb[i][idx])) 33 | print('\n') 34 | 35 | def flush(self): 36 | self.sb = [] 37 | -------------------------------------------------------------------------------- /Utils/data_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | SPEECH_DATA_PATH = './../Data' 6 | SPEECH_DATA_PATH = './Data_Clean' 7 | 8 | mode = 'train' 9 | # mode = 'dev' 10 | 11 | data = np.load(os.path.join(SPEECH_DATA_PATH, 12 | '{}_transcripts.npy'.format(mode)), 13 | encoding='bytes') 14 | 15 | dup_list = [] 16 | duplicates = np.array([[0.1, 0.1]]) 17 | idx_list = list(np.arange(len(data))) 18 | for i in range(len(data)): 19 | if i in dup_list: 20 | continue 21 | else: 22 | curr_dups = [] 23 | for j in idx_list: 24 | if np.array_equal(data[j], data[i]) and (j != i): 25 | duplicates = np.vstack((duplicates, np.array([i, j]))) 26 | curr_dups.append(j) 27 | curr_dups.append(i) 28 | dup_list.extend(curr_dups) 29 | for dup in curr_dups: 30 | idx_list.remove(dup) 31 | 
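# The first row of `duplicates` is only the [0.1, 0.1] placeholder used to give
# np.vstack something to stack onto; the slice below drops it, leaving one
# (kept_index, duplicate_index) pair per detected duplicate transcript.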
32 | duplicates = duplicates[1:, :] 33 | assert (len(duplicates) == len(np.unique(duplicates[:, 1]))) 34 | 35 | per = len(duplicates) / len(data) * 100 36 | print('Number of duplicate instances in %s data: %d/%d, %.2f %%' % 37 | (mode, len(duplicates), len(data), per)) 38 | 39 | remove_idx = np.unique(duplicates[:, 1]) 40 | np.save(mode + '_duplicates.npy', remove_idx) 41 | 42 | # Test 43 | if len(duplicates) > 1: 44 | idx = duplicates[0, 0] 45 | print('Testing with %s instance: %d' % (mode.upper(), int(idx))) 46 | sample_list = np.where(idx == duplicates[:, 0])[0] 47 | for s in sample_list: 48 | print(mode.upper(), 'Instance:', int(duplicates[s, 1])) 49 | print(data[int(duplicates[s, 1])]) 50 | -------------------------------------------------------------------------------- /Code/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.utils.rnn as rnn 6 | from torch.utils.data import Dataset as Dataset 7 | 8 | SPEECH_DATA_PATH = './../Data_Orig' 9 | 10 | IGNORE_ID = -1 11 | VOCAB = [ 12 | '', ' ', "'", '+', '-', '.', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 13 | 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 14 | 'X', 'Y', 'Z', '_' 15 | ] 16 | 17 | 18 | class SpeechDataset(Dataset): 19 | def __init__(self, mode='train'): 20 | # Check for valid mode. 21 | self.mode = mode 22 | valid_modes = {'train', 'dev', 'test'} 23 | if self.mode not in valid_modes: 24 | raise ValueError("SpeechDataset Error: Mode must be one of %r." % 25 | valid_modes) 26 | self.vocab = VOCAB 27 | self.vocab_size = len(VOCAB) 28 | # Load the data and labels (labels = None for 'test' mode) 29 | # Labels must be a set of strings. 30 | self.data, self.labels_raw = self.loadRawData() 31 | if self.mode != 'test': 32 | self.labels = self.labels_raw.copy() 33 | self._preprocess_labels() 34 | 35 | def _preprocess_labels(self): 36 | for idx in range(len(self.labels)): 37 | label = [] 38 | for c in self.labels[idx]: 39 | label.append(self.vocab.index(c)) 40 | label.append(self.vocab.index(' ')) 41 | label[-1] = self.vocab.index('') 42 | label = [self.vocab.index('')] + label 43 | self.labels[idx] = torch.from_numpy(np.array(label)).long() 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | def __getitem__(self, idx): 49 | data = torch.from_numpy(self.data[idx]) 50 | if self.mode == 'test': 51 | label = None 52 | else: 53 | label = self.labels[idx] 54 | return data, label 55 | 56 | def loadRawData(self): 57 | if self.mode == 'train' or self.mode == 'dev': 58 | return (np.load(os.path.join(SPEECH_DATA_PATH, 59 | '{}.npy'.format(self.mode)), 60 | encoding='bytes'), 61 | np.load(os.path.join( 62 | SPEECH_DATA_PATH, 63 | '{}_transcripts.npy'.format(self.mode)), 64 | encoding='bytes')) 65 | else: # No labels in test mode. 66 | return (np.load(os.path.join(SPEECH_DATA_PATH, 'test.npy'), 67 | encoding='bytes'), None) 68 | 69 | def _generate_vocab(self): 70 | vocab = set({}) 71 | for utt in self.labels: 72 | for w in utt: 73 | for c in w: 74 | vocab.add(chr(c)) 75 | vocab = list(sorted(vocab)) 76 | vocab = ['', ' '] + vocab 77 | return vocab 78 | 79 | 80 | # Modify the batch in collate_fn to sort the 81 | # batch in decreasing order of size. 
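# Usage sketch (illustrative only, mirroring how main.py builds its loaders):
# SpeechCollateFn below is meant to be passed to torch.utils.data.DataLoader
# as its collate_fn, e.g.
#   loader = DataLoader(SpeechDataset('train'), batch_size=32,
#                       collate_fn=SpeechCollateFn)
# so that every batch arrives sorted by utterance length (longest first), the
# ordering that pack_sequence/pack_padded_sequence in the encoder rely on.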
82 | def SpeechCollateFn(seq_list): 83 | inputs, targets = zip(*seq_list) 84 | inp_lens = [len(seq) for seq in inputs] 85 | seq_order = sorted(range(len(inp_lens)), 86 | key=inp_lens.__getitem__, 87 | reverse=True) 88 | inputs = [inputs[i].type(torch.float32) 89 | for i in seq_order] # RNN does not accept Float64. 90 | inp_lens = [len(seq) for seq in inputs] 91 | tar_lens = [] 92 | targets_loss = None 93 | if targets[0] is not None: 94 | targets = [targets[i] for i in seq_order] 95 | tar_lens = [len(tar) for tar in targets] 96 | targets_loss = rnn.pad_sequence(targets, padding_value=IGNORE_ID) 97 | targets = rnn.pad_sequence(targets) 98 | return inputs, inp_lens, targets, targets_loss, tar_lens, seq_order 99 | -------------------------------------------------------------------------------- /Code/beamsearch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Node class for beam search. 5 | class BeamNode(object): 6 | def __init__(self, parent, state, value, cost, extras): 7 | super(BeamNode, self).__init__() 8 | # value/predicted word for current node. 9 | self.value = value 10 | # parent Node, None for root 11 | self.parent = parent 12 | # current node's lstm hidden state 13 | self.state = state 14 | # cumulative cost of entire chain upto current node. 15 | self.cum_cost = parent.cum_cost + cost if parent else cost 16 | # length of entire chain 17 | self.length = 1 if parent is None else parent.length + 1 18 | # any extra variables to store for the node 19 | self.extras = extras 20 | # to hold the node's entire sequence. 21 | self._sequence = None 22 | 23 | def to_sequence(self): 24 | # Return sequence of nodes from root to current node. 25 | if not self._sequence: 26 | self._sequence = [] 27 | current_node = self 28 | while current_node: 29 | self._sequence.insert(0, current_node) 30 | current_node = current_node.parent 31 | return self._sequence 32 | 33 | def to_sequence_of_values(self): 34 | return [s.value for s in self.to_sequence()] 35 | 36 | def to_sequence_of_extras(self): 37 | return [s.extras for s in self.to_sequence()] 38 | 39 | 40 | def beam_search(initial_state_function, 41 | generate_function, 42 | X, 43 | end_id, 44 | batch_size=1, 45 | beam_width=4, 46 | num_hypotheses=1, 47 | max_length=50, 48 | vocab_size=33): 49 | # initial_state_function: A function that takes X as input and returns 50 | # state (2-dimensonal numpy array with 1 row 51 | # representing decoder recurrent layer state). 52 | # generate_function: A function that takes Y_tm1 (1-dimensional numpy array 53 | # of token indices in decoder vocabulary generated at 54 | # previous step) and state_tm1 (2-dimensonal numpy array 55 | # of previous step decoder recurrent layer states) as 56 | # input and returns state_t (2-dimensonal numpy array of 57 | # current step decoder recurrent layer states), 58 | # p_t (2-dimensonal numpy array of decoder softmax 59 | # outputs) and optional extras (e.g. attention weights 60 | # at current step). 61 | # X: List of input token indices in encoder vocabulary. 62 | # end_id: Index of token in decoder vocabulary. 63 | # batch_size: Batch size. TBD ! 64 | # beam_width: Beam size. Default 4. (NOTE: Fails for beam > vocab) 65 | # num_hypotheses: Number of hypotheses to generate. Default 1. 66 | # max_length: Length limit for generated sequence. Default 50. 
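# Expected callback shapes, sketched from the parameter descriptions above and
# from Seq2Seq.get_initial_state/generate in model.py (illustrative only):
#   state_0, y_0 = initial_state_function(X, batch_size)
#   state_t, p_t, extras_t = generate_function(Y_tm1, state_tm1)
# where p_t[n] is the softmax distribution over the decoder vocabulary for the
# n-th live hypothesis at the current step.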
67 | initial_state, initial_value = initial_state_function(X, batch_size) 68 | next_fringe = [BeamNode(parent=None, 69 | state=initial_state, 70 | value=initial_value, 71 | cost=0.0, 72 | extras=None) 73 | ] 74 | hypotheses = [] 75 | 76 | for step in range(max_length): 77 | fringe = [] 78 | for n in next_fringe: 79 | if (step != 0 and n.value == end_id) or step == max_length - 1: 80 | hypotheses.append(n) 81 | else: 82 | fringe.append(n) 83 | 84 | if not fringe or len(hypotheses) >= num_hypotheses: 85 | # if not fringe: 86 | break 87 | 88 | Y_tm1 = [n.value for n in fringe] 89 | state_tm1 = [n.state for n in fringe] 90 | state_t, p_t, extras_t = generate_function(Y_tm1, state_tm1) 91 | Y_t = np.argsort( 92 | p_t, axis=1 93 | )[:, -beam_width:] # no point in taking more than fits in the beam 94 | next_fringe = [] 95 | for Y_t_n, p_t_n, extras_t_n, state_t_n, n in zip( 96 | Y_t, p_t, extras_t, state_t, fringe): 97 | Y_nll_t_n = -np.log(p_t_n[Y_t_n]) 98 | 99 | for y_t_n, y_nll_t_n in zip(Y_t_n, Y_nll_t_n): 100 | n_new = BeamNode(parent=n, 101 | state=state_t_n, 102 | value=y_t_n, 103 | cost=y_nll_t_n, 104 | extras=extras_t_n) 105 | next_fringe.append(n_new) 106 | 107 | next_fringe = sorted( 108 | next_fringe, key=lambda n: n.cum_cost 109 | )[:beam_width] # may move this into loop to save memory 110 | 111 | hypotheses.sort(key=lambda n: n.cum_cost) 112 | return hypotheses[:num_hypotheses] 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Speech Recognition 2 | 3 | ## Introduction 4 | 5 | An **automatic speech recognition** system should be able to transcribe a given speech utterance to its corresponding transcript, end-to-end. We are provided with the utterances and their corresponding transcript. We can achieve this by using a combination of Recurrent Neural Networks (RNNs) / Convolutional Neural Networks (CNNs) and Dense Networks to design a system for speech to text transcription. 6 | 7 | ## Design 8 | 9 | There are many ways to approach this problem. This project makes use of an **attention** based system. Attention Mechanisms are widely used for various applications these days. More often than not, speech tasks can also be extended to images. Specifically this repo implements a variation of **Listen, Attend and Spell**. 10 | 11 | ### Listen, Attend and Spell 12 | 13 | The idea is to learn all components of a speech recogniser jointly. The paper describes an encoder-decoder approach, called Listener and Speller respectively. 14 | 15 | The **Listener** consists of a **Pyramidal Bi-LSTM Network** structure that takes in the given utterances and compresses it to produce high-level representations for the Speller network. 16 | 17 | The **Speller** takes in the high-level feature output from the Listener network and uses it to compute a probability distribution over sequences of characters using the **attention mechanism**. 18 | 19 | Attention intuitively can be understood as trying to learn a mapping from a word vector to some areas of the utterance map. The Listener produces a high-level representation of the given utterance and the Speller uses parts of the representation (produced from the Listener) to predict the next word in the sequence. 20 | 21 | ### Variation to LAS 22 | 23 | The LAS model only uses a single projection from the Listener network. 
However, we can instead take two projections and use them as an Attention Key and an Attention Value. 24 | 25 | The encoder network in this case produces two outputs, an attention **value** and a **key**, and the decoder network over the transcripts produces an attention **query**. The dot product between the query and the key is called the **energy** of the attention. 26 | 27 | Subsequently, we feed that energy into a Softmax and use the resulting distribution as a mask to take a weighted sum over the attention values, that is, we apply the attention mask to the values from the encoder. This weighted sum is called the attention **context**, which is fed back into the transcript (decoder) network. 28 | 29 | ### Variable Length Inputs 30 | 31 | Both the transcripts and the utterances are of variable length. To deal with this, we use PyTorch's built-in `pack_padded_sequence` and `pad_packed_sequence` APIs, which pack the variable-length inputs into a single combined tensor that can be fed into the Encoder. 32 | 33 | ### Listener/Encoder 34 | 35 | The encoder is the part that runs over the utterances to produce the attention keys and values. Given a batch of utterances, a layer of Bi-LSTMs first extracts features. We then perform a pooling-like operation by concatenating the outputs of adjacent time-steps, which halves the sequence length. We do this three times, as described in the paper, and finally project the last layer's output into an attention key and value pair. 36 | 37 | ### Speller/Decoder 38 | 39 | The decoder is an LSTM that takes character[t] as input and produces character[t+1] as output at each time-step. The decoder also receives additional information through the attention context. As a consequence, we cannot use PyTorch's `nn.LSTM` implementation directly, and we 40 | instead have to use `nn.LSTMCell` and run each time-step in a for loop. 41 | 42 | ### Teacher Forcing 43 | 44 | One problem we encounter in this setting is the mismatch between training and evaluation: at test time the model is fed its own generated characters (to predict the output at t+1), whereas during training it is used to being fed the ground-truth characters. Feeding the ground-truth character at every step is known as teacher forcing. To make the network robust to its own prediction noise, we apply teacher forcing only with some probability during training and otherwise feed in the generated character; this probability (the teacher-forcing rate) is annealed over the course of training. 45 | 46 | ## Dataset and Preprocessing 47 | 48 | The Wall Street Journal (WSJ) dataset was used for this work. The transcripts are provided as raw text, so we can use either a character-based or a word-based model. 49 | 50 | Word-based models cannot produce misspellings and train quickly because the number of tokens per utterance drops drastically; the problem is that they cannot predict rare words. 51 | 52 | Character-based models can predict even very rare words, but they are slower to train because the model must predict the transcript character by character. 53 | 54 | This repo implements the character-based model. Hence we preprocess the data by splitting the raw text (sentences) into characters, and each character is then mapped to a unique integer (refer to VOCAB in dataset.py). 55 | 56 | Each transcript/utterance is a separate sample of variable length. In order to predict all characters, we add a start and an end character to our vocabulary. To keep things simple, both can map to the same index, such as 0.
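As a rough sketch of this index mapping (the helper `encode_transcript` below is hypothetical; the real logic lives in `_preprocess_labels` in `dataset.py`, which uses the same `VOCAB` with the empty string at index 0 serving as both start and end token):

```python
# Character vocabulary from dataset.py: '' (sos/eos), space, punctuation, A-Z, '_'.
VOCAB = ['', ' ', "'", '+', '-', '.'] \
    + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + ['_']

def encode_transcript(text):
    # Wrap the character indices with the shared start/end token (index 0).
    return [0] + [VOCAB.index(ch) for ch in text] + [0]

print(encode_transcript('HELLO'))  # [0, 13, 10, 17, 17, 20, 0]
```

The decoder is then trained to predict this sequence shifted by one position, which is exactly the `[start]hello` to `hello[end]` pairing shown in the example below.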
57 | 58 | For example, if the utterance is hello, then: 59 | - inputs=[start]hello 60 | - outputs=hello[end] 61 | 62 | Refer to in VOCAB list in dataset.py. 63 | 64 | ## Evaluation 65 | 66 | Performance is evaluated using CER - character error rate (edit distance). 67 | 68 | ## Results 69 | 70 | The given model achieves **CER of 10.63** on WSJ dataset. 71 | 72 | ## References 73 | 74 | - **Listen, Attend and Spell**: https://arxiv.org/pdf/1508.01211.pdf 75 | -------------------------------------------------------------------------------- /Code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import (pack_padded_sequence, pack_sequence, 6 | pad_packed_sequence) 7 | 8 | 9 | class LockedDropout(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, dropout=0.5): 14 | # x: (L, B, C) 15 | if dropout == 0 or not self.training: 16 | return x 17 | mask = x.data.new(1, x.size(1), x.size(2)) 18 | mask = mask.bernoulli_(1 - dropout) 19 | mask = Variable(mask, requires_grad=False) / (1 - dropout) 20 | mask = mask.expand_as(x) 21 | return mask * x 22 | 23 | 24 | ''' 25 | class WeightDrop(nn.Module): 26 | def __init__(self, module, weights, dropout=0): 27 | super(WeightDrop, self).__init__() 28 | self.module = module 29 | self.weights = weights 30 | self.dropout = dropout 31 | if issubclass(type(self.module), torch.nn.RNNBase): 32 | self.module.flatten_parameters = self.dummy_flatten 33 | for name_w in self.weights: 34 | print('Applying weight drop of {} to {}'.format(self.dropout, 35 | name_w)) 36 | w = getattr(self.module, name_w) 37 | del self.module._parameters[name_w] 38 | self.module.register_parameter(name_w + '_raw', nn.Parameter(w)) 39 | 40 | def dummy_flatten(*args, **kwargs): 41 | return 42 | 43 | def forward(self, *args): 44 | for name_w in self.weights: 45 | raw_w = getattr(self.module, name_w + '_raw') 46 | w = torch.nn.functional.dropout(raw_w, p=self.dropout, 47 | training=self.training) 48 | setattr(self.module, name_w, nn.Parameter(w)) 49 | return self.module.forward(*args) 50 | ''' 51 | 52 | 53 | class BackHook(torch.nn.Module): 54 | def __init__(self, hook): 55 | super(BackHook, self).__init__() 56 | self._hook = hook 57 | self.register_backward_hook(self._backward) 58 | 59 | def forward(self, *inp): 60 | return inp 61 | 62 | @staticmethod 63 | def _backward(self, grad_in, grad_out): 64 | self._hook() 65 | return None 66 | 67 | 68 | class WeightDrop(torch.nn.Module): 69 | """ 70 | Implements drop-connect, as per Merity, https://arxiv.org/abs/1708.02182 71 | """ 72 | def __init__(self, module, weights, dropout=0, variational=False): 73 | super(WeightDrop, self).__init__() 74 | self.module = module 75 | self.weights = weights 76 | self.dropout = dropout 77 | self.variational = variational 78 | self._setup() 79 | self.hooker = BackHook(lambda: self._backward()) 80 | 81 | def _setup(self): 82 | for name_w in self.weights: 83 | print('Applying weight drop of {} to {}'.format( 84 | self.dropout, name_w)) 85 | w = getattr(self.module, name_w) 86 | self.register_parameter(name_w + '_raw', nn.Parameter(w.data)) 87 | 88 | def _setweights(self): 89 | for name_w in self.weights: 90 | raw_w = getattr(self, name_w + '_raw') 91 | if self.training: 92 | mask = raw_w.new_ones((raw_w.size(0), 1)) 93 | mask = torch.nn.functional.dropout(mask, 94 | p=self.dropout, 95 | 
training=True) 96 | w = mask.expand_as(raw_w) * raw_w 97 | setattr(self, name_w + "_mask", mask) 98 | else: 99 | w = raw_w 100 | rnn_w = getattr(self.module, name_w) 101 | rnn_w.data.copy_(w) 102 | 103 | def _backward(self): 104 | # transfer gradients from embeddedRNN to raw params 105 | for name_w in self.weights: 106 | raw_w = getattr(self, name_w + '_raw') 107 | rnn_w = getattr(self.module, name_w) 108 | raw_w.grad = rnn_w.grad * getattr(self, name_w + "_mask") 109 | 110 | def forward(self, *args): 111 | self._setweights() 112 | return self.module(*self.hooker(*args)) 113 | 114 | 115 | class Encoder(nn.Module): 116 | def __init__(self, base=128, device="cpu"): 117 | super(Encoder, self).__init__() 118 | self.lstm1 = nn.LSTM(40, base, bidirectional=True) 119 | self.lstm2 = self.__make_layer__(base * 4, base) 120 | self.lstm3 = self.__make_layer__(base * 4, base) 121 | self.lstm4 = self.__make_layer__(base * 4, base) 122 | 123 | self.fc1 = nn.Linear(base * 2, base * 2) 124 | self.fc2 = nn.Linear(base * 2, base * 2) 125 | self.act = nn.SELU(inplace=True) 126 | 127 | self.drop = LockedDropout() 128 | self.device = device 129 | 130 | def _stride2(self, x): 131 | x = x[:x.size(0) // 2 * 2] # make even 132 | x = self.drop(x, dropout=0.3) 133 | x = x.permute(1, 0, 2) # seq, batch, feature -> batch, seq, feature 134 | x = x.reshape(x.size(0), x.size(1) // 2, x.size(2) * 2) 135 | x = x.permute(1, 0, 2) # batch, seq, feature -> seq, batch, feature 136 | return x 137 | 138 | def __make_layer__(self, in_dim, out_dim): 139 | lstm = nn.LSTM(input_size=in_dim, 140 | hidden_size=out_dim, 141 | bidirectional=True) 142 | # return lstm 143 | return WeightDrop(lstm, ['weight_hh_l0', 'weight_hh_l0_reverse'], 144 | dropout=0.5) 145 | 146 | def forward(self, x): 147 | # x is list of variable length inputs. 
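# Shape walk-through (descriptive only, assuming the default base=128 and
# 40-dimensional input features): each _stride2 call concatenates pairs of
# adjacent time-steps, halving the sequence length and doubling the feature
# dimension. After the three reductions below, the keys and values therefore
# cover seq_len // 8 time-steps, which is why the encoder also returns
# seq_len // 8 as the new lengths.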
148 | x = pack_sequence(x) # seq, batch, 40 149 | x = x.to(self.device) 150 | 151 | x, _ = self.lstm1(x) # seq, batch, base*2 152 | x, seq_len = pad_packed_sequence(x) 153 | x = self._stride2(x) # seq//2, batch, base*4 154 | 155 | x = pack_padded_sequence(x, seq_len // 2) 156 | x, _ = self.lstm2(x) # seq//2, batch, base*2 157 | x, _ = pad_packed_sequence(x) 158 | x = self._stride2(x) # seq//4, batch, base*4 159 | 160 | x = pack_padded_sequence(x, seq_len // 4) 161 | x, _ = self.lstm3(x) # seq//4, batch, base*2 162 | x, _ = pad_packed_sequence(x) 163 | x = self._stride2(x) # seq//8, batch, base*4 164 | 165 | x = pack_padded_sequence(x, seq_len // 8) 166 | x, (hidden, _) = self.lstm4(x) # seq//8, batch, base*2 167 | x, _ = pad_packed_sequence(x) 168 | 169 | key = self.act(self.fc1(x)) 170 | value = self.act(self.fc2(x)) 171 | hidden = torch.cat([hidden[0, :, :], hidden[1, :, :]], dim=1) 172 | return seq_len // 8, key, value, hidden 173 | 174 | 175 | class Attention(nn.Module): 176 | def __init__(self): 177 | super(Attention, self).__init__() 178 | 179 | def forward(self, hidden2, key, value, mask): 180 | # key: seq, batch, base # value: seq, batch, base 181 | # mask: batch, seq # hidden2: batch, base 182 | # batch, 1, base X batch, base, seq -> batch, 1, seq 183 | attn = torch.bmm(hidden2.unsqueeze(1), key.permute(1, 2, 0)) 184 | attn = F.softmax(attn, dim=2) 185 | attn = attn * mask.unsqueeze(1).float() 186 | attn = attn / attn.sum(2).unsqueeze(2) 187 | 188 | # batch, 1, seq X batch, seq, base -> batch, 1, base 189 | context = torch.bmm(attn, value.permute(1, 0, 2)).squeeze(1) 190 | 191 | # context: batch, 1, base -> batch, base 192 | # attn: batch, 1, seq -> batch, seq 193 | return context.squeeze(1), attn.cpu().squeeze(1).data.numpy() 194 | 195 | 196 | class Decoder(nn.Module): 197 | def __init__(self, vocab_dim, lstm_dim): 198 | super(Decoder, self).__init__() 199 | self.embed = nn.Embedding(vocab_dim, lstm_dim) 200 | self.lstm1 = nn.LSTMCell(lstm_dim * 2, lstm_dim) 201 | self.lstm2 = nn.LSTMCell(lstm_dim, lstm_dim) 202 | self.drop = nn.Dropout(0.3) 203 | self.fc = nn.Linear(lstm_dim, vocab_dim) 204 | self.fc.weight = self.embed.weight # weight tying 205 | 206 | def forward(self, x, context, hidden1, cell1, hidden2, cell2, first_step): 207 | # x is batch x 1. Contains word for previous timestep. 
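# The embedded previous character (batch, lstm_dim) is concatenated with the
# attention context (batch, lstm_dim), matching the (batch, lstm_dim * 2)
# input size of self.lstm1. On the first step no (hidden, cell) tuple is
# passed, so both LSTMCells start from their default zero state.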
208 | x = self.embed(x) 209 | x = torch.cat([x, context], dim=1) 210 | if first_step: 211 | hidden1, cell1 = self.lstm1(x) 212 | hidden2, cell2 = self.lstm2(hidden1) 213 | else: 214 | hidden1, cell1 = self.lstm1(x, (hidden1, cell1)) 215 | hidden2, cell2 = self.lstm2(hidden1, (hidden2, cell2)) 216 | x = self.drop(hidden2) 217 | x = self.fc(x) 218 | return x, hidden1, cell1, hidden2, cell2 219 | 220 | 221 | class Seq2Seq(nn.Module): 222 | def __init__(self, base, vocab_dim, device="cpu"): 223 | super().__init__() 224 | self.base = base 225 | self.device = device 226 | self.vocab_dim = vocab_dim 227 | self.encoder = Encoder(base=base, device=device) 228 | self.attention = Attention() 229 | self.decoder = Decoder(vocab_dim=vocab_dim, lstm_dim=base * 2) 230 | 231 | for name, param in self.named_parameters(): 232 | if 'weight' in name: 233 | nn.init.orthogonal_(param.data) 234 | else: 235 | nn.init.constant_(param.data, 0) 236 | 237 | def sample_gumbel(self, shape, eps=1e-10, out=None): 238 | U = out.resize_(shape).uniform_() if out is not None else torch.rand( 239 | shape) 240 | return -torch.log(eps - torch.log(U + eps)) 241 | 242 | def forward(self, inputs, words, TF): 243 | if self.training: 244 | word, hidden1, cell1, hidden2, cell2 = words[ 245 | 0, :], None, None, None, None 246 | words = words[1:, :] # Removing sos, already saved in word 247 | max_len, batch_size = words.shape[0], words.shape[1] 248 | else: 249 | max_len, batch_size = 251, len(inputs) 250 | word = torch.zeros(batch_size).long().to(self.device) 251 | hidden1, cell1, hidden2, cell2 = None, None, None, None 252 | TF = 0 # no teacher forcing for test and val. 253 | 254 | prediction = torch.zeros(max_len, batch_size, 255 | self.vocab_dim).to(self.device) 256 | 257 | # Run through encoder. 258 | lens, key, value, hidden2 = self.encoder(inputs) 259 | mask = torch.arange(lens.max()).unsqueeze(0) < lens.unsqueeze(1) 260 | mask = mask.to(self.device) 261 | 262 | attention_weights = [] 263 | 264 | for t in range(max_len): 265 | context, attention = self.attention(hidden2, key, value, mask) 266 | word_vec, hidden1, cell1, hidden2, cell2 = self.decoder( 267 | word, 268 | context, 269 | hidden1, 270 | cell1, 271 | hidden2, 272 | cell2, 273 | first_step=(t == 0)) 274 | prediction[t] = word_vec 275 | teacher_force = torch.rand(1) < TF 276 | if teacher_force: 277 | word = words[t] 278 | else: 279 | gumbel = torch.autograd.Variable( 280 | self.sample_gumbel(shape=word_vec.size(), 281 | out=word_vec.data.new())) 282 | word_vec += gumbel 283 | word = word_vec.max(1)[1] 284 | attention_weights.append(attention) 285 | return prediction, attention_weights 286 | 287 | def get_initial_state(self, inputs, batch_size): 288 | self._lens, self._key, self._value, hidden2 = self.encoder(inputs) 289 | self._mask = torch.arange( 290 | self._lens.max()).unsqueeze(0) < self._lens.unsqueeze(1) 291 | self._mask = self._mask.to(self.device) 292 | word = torch.zeros(batch_size).long().to(self.device) 293 | return [None, None, hidden2, None], word # Initial state is none. 
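# get_initial_state (above) and generate (below) are the two callbacks that
# beam_search in beamsearch.py consumes: the former encodes the utterance once
# and returns the starting decoder state plus the start token, while the latter
# advances every live hypothesis by one decoding step and returns softmax
# probabilities plus attention weights for each.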
294 | 295 | def generate(self, prev_words, prev_states): 296 | new_states, raw_preds, attention_scores = [], [], [] 297 | for prev_word, prev_state in zip(prev_words, prev_states): 298 | prev_word = Variable( 299 | self._value.data.new(1).long().fill_(int(prev_word))) 300 | hidden1, cell1, hidden2, cell2 = prev_state[0], prev_state[ 301 | 1], prev_state[2], prev_state[3] 302 | context, attention = self.attention(hidden2, self._key, 303 | self._value, self._mask) 304 | first_step = False 305 | if prev_state[0] is None: 306 | first_step = True # First timestep. 307 | word_vec, hidden1, cell1, hidden2, cell2 = self.decoder( 308 | prev_word, 309 | context, 310 | hidden1, 311 | cell1, 312 | hidden2, 313 | cell2, 314 | first_step=first_step) 315 | new_state = [hidden1, cell1, hidden2, cell2] 316 | new_states.append(new_state) 317 | raw_preds.append( 318 | F.softmax(word_vec, dim=1).squeeze().data.cpu().numpy()) 319 | attention_scores.append(attention) 320 | return new_states, raw_preds, attention_scores 321 | -------------------------------------------------------------------------------- /Code/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import time 5 | 6 | import Levenshtein as Lev 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.utils.data import DataLoader 14 | 15 | from beamsearch import beam_search 16 | from dataset import * # noqa F403 17 | from model import Seq2Seq 18 | from scoreboard import Scoreboard 19 | 20 | # Paths 21 | MODEL_PATH = './Models' 22 | TEST_RESULT_PATH = './Results' 23 | GRAD_FIGURES_PATH = './Grad_Figures' 24 | 25 | # Defaults 26 | DEFAULT_RUN_MODE = 'train' 27 | DEFAULT_FEATURE_SIZE = 40 28 | DEFAULT_TRAIN_BATCH_SIZE = 32 29 | DEFAULT_TEST_BATCH_SIZE = 32 30 | DEFAULT_RANDOM_SEED = 2222 31 | SCOREBOARD_KEY = ["CER", "Mine", "Label"] 32 | 33 | # Hyperparameters. 34 | LEARNING_RATE = 1e-3 35 | WEIGHT_DECAY = 1.2e-6 36 | GRADIENT_CLIP = 0 # 0.25 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser( 41 | description='Training/testing for Speech Recognition.') 42 | parser.add_argument('--mode', type=str, choices=['train', 'test'], 43 | default=DEFAULT_RUN_MODE, 44 | help='\'train\' or \'test\' mode.') 45 | parser.add_argument('--train_batch_size', type=int, 46 | default=DEFAULT_TRAIN_BATCH_SIZE, 47 | help='Training batch size.') 48 | parser.add_argument('--test_batch_size', type=int, 49 | default=DEFAULT_TEST_BATCH_SIZE, 50 | help='Testing batch size.') 51 | parser.add_argument('--model_path', type=str, 52 | help='Path to model to be reloaded.') 53 | return parser.parse_args() 54 | 55 | 56 | def generate_labels_string(batch_pred, vocab): 57 | # Loop over entire batch list of predicted labels 58 | # and convert them to strings. 59 | batch_strings = [] 60 | for pred in batch_pred: 61 | batch_strings.append(''.join([vocab[pred[i]] 62 | for i in range(len(pred))])) 63 | return batch_strings 64 | 65 | 66 | def find_best_word(sample, word_dict): 67 | # find best possible/closest word to predicted word. 
68 | if sample in word_dict: 69 | return sample, 0 70 | else: 71 | best_word = sample 72 | best_dist = 500 73 | for idx, word in enumerate(word_dict): 74 | dist = Lev.distance(sample, word) 75 | if dist < best_dist: 76 | best_dist = dist 77 | best_word = word 78 | return best_word, best_dist 79 | 80 | 81 | def map_strings_to_closest_words(pred_list, word_dict): 82 | # cleans up predicted strings by mapping them 83 | # to the closest word in word dict 84 | print('\nCleaning up predicted strings.') 85 | new_pred_list = [] 86 | for idx, pred_str in enumerate(pred_list): 87 | print('Predicted String: %d/%d' % (idx+1, len(pred_list)), 88 | end="\r", flush=True) 89 | new_pred_str = [] 90 | words = pred_str.split(" ") 91 | for w in words: 92 | new_w, _ = find_best_word(w, word_dict) 93 | new_pred_str.append(new_w) 94 | new_pred_list.append(" ".join(new_pred_str)) 95 | return new_pred_list 96 | 97 | 98 | def character_error_rate(pred, targets): 99 | assert len(pred) == len(targets) 100 | dist = [] 101 | for idx, p in enumerate(pred): 102 | dist.append(Lev.distance(p, targets[idx])) 103 | return dist 104 | 105 | 106 | def plot_grad_flow(named_parameters, batch, epoch): 107 | ave_grads = [] 108 | layers = [] 109 | for n, p in named_parameters: 110 | if(p.requires_grad) and ("bias" not in n): 111 | layers.append(n) 112 | if p.grad is None: 113 | ave_grads.append(-10) 114 | else: 115 | ave_grads.append(p.grad.abs().mean()) 116 | fig = plt.figure() 117 | plt.plot(ave_grads, alpha=0.3, color="r") 118 | plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k") 119 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 120 | plt.xlim(xmin=0, xmax=len(ave_grads)) 121 | plt.xlabel("Layers") 122 | plt.ylabel("average gradient") 123 | plt.title("Gradient flow") 124 | plt.grid(True) 125 | plt.tight_layout() 126 | save_path = os.path.join(GRAD_FIGURES_PATH, str(epoch)) 127 | if not os.path.isdir(save_path): 128 | os.makedirs(save_path, exist_ok=True) 129 | fig.savefig(save_path+"/gradient_flow_"+str(batch)+".png") 130 | plt.close() 131 | 132 | 133 | def greedy_decode(outputs, eos_token): 134 | probs = F.softmax(outputs, dim=2) 135 | preds = torch.argmax(probs, dim=2) 136 | # Iterate over each item in batch. 137 | pred_list = [] 138 | for i in range(preds.size(0)): 139 | eos_idx = (preds[i] == eos_token).nonzero() 140 | eos_idx = (len(preds[i])-1) if eos_idx.nelement() == 0 else eos_idx[0] 141 | # pick all predicted chars excluding eos 142 | pred_list.append(preds[i, :eos_idx]) 143 | return pred_list 144 | 145 | 146 | def decode_and_cer(outputs, targets, tar_lens, vocab): 147 | eos_token = vocab.index('') 148 | pred_list = greedy_decode(outputs, eos_token) 149 | # exclude eos and sos 150 | tar_list = [targets[i, 1:(tar_lens[i]-1)] for i in range(targets.shape[0])] 151 | # Calculate the strings for predictions. 152 | pred_str = generate_labels_string(pred_list, vocab) 153 | # Calculate the strings for targets. 154 | tar_str = generate_labels_string(tar_list, vocab) 155 | # Calculate edit distance between predictions and targets. 
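# For example, Lev.distance("HELLO WORLD", "HELO WORLD") == 1 (one deletion),
# so each entry of the returned distance list is the raw character-level edit
# distance for one utterance; the caller averages these to report the CER.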
156 | return character_error_rate(pred_str, tar_str), pred_str, tar_str 157 | 158 | 159 | def save_test_results(predictions): 160 | predictions_count = list(range(len(predictions))) 161 | csv_output = [[i, j] for i, j in zip(predictions_count, predictions)] 162 | if not os.path.isdir(TEST_RESULT_PATH): 163 | os.makedirs(TEST_RESULT_PATH, exist_ok=True) 164 | result_file_path = os.path.join(TEST_RESULT_PATH, 'result_{}.csv'.format( 165 | (str.split(str.split(args.model_path, '/')[-1], '.pt')[0]))) 166 | with open(result_file_path, mode='w') as csv_file: 167 | csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', 168 | quoting=csv.QUOTE_MINIMAL) 169 | csv_writer.writerow(['Id', 'Predicted']) 170 | csv_writer.writerows(csv_output) 171 | 172 | 173 | def test_model(model, test_loader, device): 174 | with torch.no_grad(): 175 | model.eval() 176 | start_time = time.time() 177 | all_predictions = [] 178 | for batch_idx, (inputs, _, _, _, _, seq_order) in \ 179 | enumerate(test_loader): 180 | outputs, attention_weights = model(inputs, None, 0) 181 | torch.save(attention_weights, 'attention_weights_test.pt') 182 | # make outputs batch first. 183 | outputs = outputs.permute(1, 0, 2) 184 | eos_token = test_loader.dataset.vocab.index('') 185 | pred_list = greedy_decode(outputs, eos_token) 186 | pred_str = generate_labels_string(pred_list, 187 | test_loader.dataset.vocab) 188 | # Input is sorted as per length for rnn. Resort the output. 189 | reorder_seq = np.argsort(seq_order) 190 | pred_str = [pred_str[i] for i in reorder_seq] 191 | all_predictions.extend(pred_str) 192 | print('Test Iteration: %d/%d' % (batch_idx+1, len(test_loader)), 193 | end="\r", flush=True) 194 | end_time = time.time() 195 | # Try to map words in strings to closest words. 196 | # all_predictions = map_strings_to_closest_words(all_predictions, 197 | # word_dict) 198 | # Save predictions in csv file. 199 | save_test_results(all_predictions) 200 | print('\nTotal Test Predictions: %d Time: %d s' % ( 201 | len(all_predictions), end_time - start_time)) 202 | 203 | 204 | # NOTE: Batch size must be one for test_model2 !! 205 | def test_model2(model, test_loader, device, word_dict): 206 | with torch.no_grad(): 207 | model.eval() 208 | start_time = time.time() 209 | all_predictions = [] 210 | eos_token = test_loader.dataset.vocab.index('') 211 | for batch_idx, (inputs, _, _, _, _, seq_order) in \ 212 | enumerate(test_loader): 213 | hypos = beam_search(model.get_initial_state, model.generate, 214 | inputs, eos_token, batch_size=1, beam_width=8, 215 | num_hypotheses=1, max_length=250) 216 | pred_list = [] 217 | for n in hypos: 218 | nn = n.to_sequence_of_values() 219 | pred_list.append(nn[1:][:-1]) 220 | attention_weights = [n.to_sequence_of_extras() for n in hypos] 221 | pred_str = generate_labels_string(pred_list, 222 | test_loader.dataset.vocab) 223 | torch.save(attention_weights, 'attention_weights_test.pt') 224 | all_predictions.extend(pred_str) 225 | print('Test Iteration: %d/%d' % (batch_idx+1, len(test_loader)), 226 | end="\r", flush=True) 227 | # Try to map words in strings to closest words. 228 | # all_predictions = map_strings_to_closest_words(all_predictions, 229 | # word_dict) 230 | # Save predictions in csv file. 
231 | save_test_results(all_predictions) 232 | end_time = time.time() 233 | print('\nTotal Test Predictions: %d Time: %d s' % ( 234 | len(all_predictions), end_time - start_time)) 235 | 236 | 237 | def val_model(model, val_loader, device, sb): 238 | with torch.no_grad(): 239 | model.eval() 240 | dist = [] 241 | start_time = time.time() 242 | for batch_idx, (inputs, _, targets, _, tar_lens, _) in \ 243 | enumerate(val_loader): 244 | targets = targets.to(device) 245 | outputs, attention_weights = model(inputs, None, 0) 246 | torch.save(attention_weights, 'attention_weights_val.pt') 247 | # make targets and outputs batch first. 248 | targets = targets.permute(1, 0) 249 | outputs = outputs.permute(1, 0, 2) 250 | # Decode and get edit distance. 251 | distances, pred_str, tar_str = \ 252 | decode_and_cer(outputs, targets, tar_lens, 253 | val_loader.dataset.vocab) 254 | dist.extend(distances) 255 | for i in range(len(distances)): 256 | sb.addItem([distances[i], pred_str[i], tar_str[i]]) 257 | print('Validation Iteration: %d/%d' % 258 | (batch_idx+1, len(val_loader)), 259 | end="\r", flush=True) 260 | end_time = time.time() 261 | dist = sum(dist)/len(dist) # Average over edit distance. 262 | print('\nValidation -> Edit Distance: %5.3f Time: %d s' % 263 | (dist, end_time - start_time)) 264 | return dist 265 | 266 | 267 | def train_model(model, train_loader, criterion, optimizer, device, tf, epoch, 268 | sb): 269 | model.train() 270 | running_loss = 0.0 271 | running_lens = 0.0 272 | dist = [] 273 | measure_training_accuracy = True 274 | start_time = time.time() 275 | for batch_idx, (inputs, _, targets, targets_loss, tar_lens, _) in \ 276 | enumerate(train_loader): 277 | targets, targets_loss = targets.to(device), targets_loss.to(device) 278 | optimizer.zero_grad() 279 | outputs, attention_weights = model(inputs, targets, tf) 280 | torch.save(attention_weights, 'attention_weights_train.pt') 281 | # make targets and outputs batch first. 282 | targets, targets_loss = \ 283 | targets.permute(1, 0), targets_loss.permute(1, 0) 284 | outputs = outputs.permute(1, 0, 2) 285 | loss = criterion(outputs.contiguous().view(-1, outputs.size(2)), 286 | targets_loss[:, 1:].contiguous().view(-1)) 287 | running_loss += loss.item() 288 | running_lens += float(sum(tar_lens)) 289 | loss = loss/len(tar_lens) # Average over batch. 290 | loss.backward() 291 | plot_grad_flow(model.named_parameters(), batch_idx+1, epoch) 292 | # To avoid exploding gradient issue. 293 | # torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP) 294 | optimizer.step() 295 | if measure_training_accuracy: 296 | distances, pred_str, tar_str = \ 297 | decode_and_cer(outputs, targets, tar_lens, 298 | train_loader.dataset.vocab) 299 | dist.extend(distances) 300 | for i in range(len(distances)): 301 | sb.addItem([distances[i], pred_str[i], tar_str[i]]) 302 | curr_avg_loss = (running_loss/running_lens) 303 | curr_perp = np.exp(curr_avg_loss) 304 | print('Train Iteration: %d/%d Loss = %5.4f, Perplexity = %5.4f' % 305 | (batch_idx+1, len(train_loader), curr_avg_loss, curr_perp), 306 | end="\r", flush=True) 307 | end_time = time.time() 308 | # Average over edit distance. 
309 | dist = sum(dist)/len(dist) if measure_training_accuracy else -1 310 | running_loss = (running_loss/running_lens) 311 | perplexity = np.exp(running_loss) 312 | print('\nTraining -> Loss: %5.4f Perplexity: %5.4f ' 313 | 'Edit Distance: %5.4f Time: %d s' 314 | % (running_loss, perplexity, dist, end_time - start_time)) 315 | return running_loss 316 | 317 | 318 | if __name__ == "__main__": 319 | # Parse args. 320 | args = parse_args() 321 | print('='*20) 322 | print('Input arguments:\n%s' % (args)) 323 | 324 | # Validate arguments. 325 | if args.mode == 'test' and args.model_path is None \ 326 | and not args.model_ensemble: 327 | raise ValueError("Input Argument Error: Test mode specified " 328 | "but model_path is %s." % (args.model_path)) 329 | 330 | # Check for CUDA. 331 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 332 | 333 | # Create datasets and dataloaders. 334 | speechTrainDataset = SpeechDataset(mode='train') # noqa F405 335 | speechValDataset = SpeechDataset(mode='dev') # noqa F405 336 | speechTestDataset = SpeechDataset(mode='test') # noqa F405 337 | 338 | train_loader = DataLoader(speechTrainDataset, 339 | batch_size=args.train_batch_size, 340 | shuffle=True, num_workers=1, 341 | collate_fn=SpeechCollateFn) # noqa F405 342 | val_loader = DataLoader(speechValDataset, batch_size=args.train_batch_size, 343 | shuffle=False, num_workers=4, 344 | collate_fn=SpeechCollateFn) # noqa F405 345 | test_loader = DataLoader(speechTestDataset, batch_size=1, 346 | shuffle=False, num_workers=4, 347 | collate_fn=SpeechCollateFn) # noqa F405 348 | 349 | # Prepare a dictionary of words. Used to clean up final predictions. 350 | WORD_DICT = [] 351 | ALL_LABELS = [speechTrainDataset.labels_raw, speechValDataset.labels_raw] 352 | for curr_labels in ALL_LABELS: 353 | for utt in curr_labels: 354 | words = utt.split(" ") 355 | for w in words: 356 | if w not in WORD_DICT: 357 | WORD_DICT.append(w) 358 | print('Prepared word dictionary: %d words' % (len(WORD_DICT))) 359 | print('='*20) 360 | 361 | # Set random seed. 362 | np.random.seed(DEFAULT_RANDOM_SEED) 363 | torch.manual_seed(DEFAULT_RANDOM_SEED) 364 | if device == "cuda": 365 | torch.cuda.manual_seed(DEFAULT_RANDOM_SEED) 366 | 367 | # Create the model. 368 | model = Seq2Seq(base=128, vocab_dim=speechTrainDataset.vocab_size, 369 | device=device) 370 | model.to(device) 371 | print('='*20) 372 | print(model) 373 | model_params = sum(p.size()[0] * p.size()[1] if len(p.size()) > 1 374 | else p.size()[0] for p in model.parameters()) 375 | print('Total model parameters:', model_params) 376 | print("Running on device = %s." % (device)) 377 | 378 | # Setup learning parameters. 
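# Note on the loss bookkeeping above: the criterion below sums over tokens and
# ignores the IGNORE_ID padding, train_model divides the loss by the batch size
# before backward(), and the running sum is divided by the total number of
# target tokens only to report per-character loss and perplexity.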
379 | criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=IGNORE_ID) # noqa F405 380 | criterion = criterion.to(device) 381 | optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, 382 | weight_decay=WEIGHT_DECAY) 383 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, 384 | threshold=0.01, 385 | verbose=True) 386 | 387 | if args.model_path is not None: 388 | model.load_state_dict(torch.load(args.model_path, map_location=device)) 389 | print('Loaded model:', args.model_path) 390 | 391 | n_epochs = 100 392 | start_epoch = 0 393 | print('='*20) 394 | 395 | train_scoreboard = Scoreboard(sort_param_idx=0, name='Train') 396 | val_scoreboard = Scoreboard(sort_param_idx=0, name='Val') 397 | 398 | teacher_force = 0.9 399 | if args.mode == 'train': 400 | for epoch in range(start_epoch, n_epochs): 401 | print('Epoch: %d/%d' % (epoch+1, n_epochs)) 402 | train_loss = train_model(model, train_loader, criterion, optimizer, 403 | device, teacher_force, epoch+1, 404 | train_scoreboard) 405 | val_dist = val_model(model, val_loader, device, val_scoreboard) 406 | # Print scoreboards. 407 | train_scoreboard.print_scoreboard(k=10, key=SCOREBOARD_KEY) 408 | val_scoreboard.print_scoreboard(k=10, key=SCOREBOARD_KEY) 409 | train_scoreboard.flush() 410 | val_scoreboard.flush() 411 | # Checkpoint the model after each epoch. 412 | finalValDist = '%.3f' % (val_dist) 413 | if not os.path.isdir(MODEL_PATH): 414 | os.mkdir(MODEL_PATH) 415 | model_path = os.path.join(MODEL_PATH, 'model_{}_val_{}.pt'.format( 416 | time.strftime("%Y%m%d-%H%M%S"), finalValDist)) 417 | torch.save(model.state_dict(), model_path) 418 | print('='*20) 419 | # Update learning rate as required. 420 | optim_state = optimizer.state_dict() 421 | if epoch < 4: 422 | optim_state['param_groups'][0]['lr'] = 1e-4 # warmup 423 | elif epoch < 30: 424 | optim_state['param_groups'][0]['lr'] = 1e-3 425 | elif epoch < 75: 426 | optim_state['param_groups'][0]['lr'] = 1e-4 427 | else: 428 | optim_state['param_groups'][0]['lr'] = 1e-5 429 | # Teacher Forcing Schedule 430 | if epoch >= 7 and teacher_force > 0.7: 431 | teacher_force -= 0.005 432 | else: 433 | # Only testing the model. 434 | test_model(model, test_loader, device) 435 | # test_model2(model, test_loader, device, WORD_DICT) 436 | print('='*20) 437 | --------------------------------------------------------------------------------