├── README.md ├── __pycache__ ├── dataset.cpython-38.pyc ├── model.cpython-38.pyc └── specaugment.cpython-38.pyc ├── dataset.py ├── model.py ├── specaugment.py └── train_1.py /README.md: -------------------------------------------------------------------------------- 1 | # E2E-audio-speech-recognition 2 | Conformer encoder + Transformer decoder with Hybrid CTC/attention 3 | 4 | Used Datasets:LRS2 5 | -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightning830/E2E-audio-speech-recognition/585b2b148876bf315a3a3eab26cfd5f132a6f8e4/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightning830/E2E-audio-speech-recognition/585b2b148876bf315a3a3eab26cfd5f132a6f8e4/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/specaugment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightning830/E2E-audio-speech-recognition/585b2b148876bf315a3a3eab26cfd5f132a6f8e4/__pycache__/specaugment.cpython-38.pyc -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | import random 6 | import numpy as np 7 | import sys 8 | from scipy.io import wavfile 9 | import librosa 10 | from torch.utils.data import Dataset 11 | from specaugment import freq_mask, time_mask 12 | 13 | 14 | class MyDataset(Dataset): 15 | def __init__(self, datasets, datadir, charToIx, audioParams, train=True): 16 | super(MyDataset, self).__init__() 17 | 18 | # dataset path list 19 | self.datalist = [] 20 | 21 | for dataset in datasets: 22 | with open(str(datadir) + "/" + str(dataset) + ".txt", "r") as f: 23 | lines = f.readlines() 24 | 25 | # pretrain 26 | if 'pre' in dataset: 27 | for line in lines: 28 | self.datalist.append(str(datadir) + "/pretrain/" + line.strip()) 29 | # main, val, test 30 | else: 31 | for line in lines: 32 | self.datalist.append(str(datadir) + "/main/" + line.strip()) 33 | 34 | self.charToIx = charToIx 35 | self.dataset = dataset 36 | self.audioParams = audioParams 37 | self.train = train 38 | return 39 | 40 | def __getitem__(self, index): 41 | # audioFile path 42 | audioFile = self.datalist[index] + ".wav" 43 | # targetFile path 44 | targetFile = self.datalist[index] + ".txt" 45 | 46 | # inp: log-mel spectrogram, shape:(inpLen, 80) 47 | # trgt: target, shape:(trgtLen,) 48 | inp, trgt, inpLen, trgtLen = self.prepare_input_logmel(self.dataset, audioFile, targetFile, 49 | self.charToIx, self.audioParams, self.train) 50 | return inp, trgt, inpLen, trgtLen 51 | 52 | def __len__(self): 53 | return len(self.datalist) 54 | 55 | 56 | 57 | def prepare_input_logmel(self, dataset, audioFile, targetFile, charToIx, audioParams, train): 58 | 59 | if targetFile is not None: 60 | 61 | #reading the target from the target file and converting each character to its corresponding index 62 | with open(targetFile, "r") as f: 63 | trgt = f.readline().strip()[7:] 64 | 65 | trgt = [charToIx[char] for char in trgt] 66 | trgt.append(charToIx[""]) 67 | trgt.insert(0, charToIx[""]) 68 | trgt = np.array(trgt) 69 | trgtLen = len(trgt) 70 | 71 | # load file 72 | sampFreq, inputAudio = wavfile.read(audioFile) 73 | inputAudio = inputAudio.astype(np.float64) 74 | # mel spectrogram 75 | mel = librosa.feature.melspectrogram(y=inputAudio, 76 | sr=sampFreq, 77 | n_mels=audioParams["Dim"], 78 | window = audioParams["Window"], 79 | n_fft=audioParams["WinLen"], 80 | hop_length=audioParams["Shift"]) 81 | 82 | # log 83 | log_mel = librosa.power_to_db(mel) # (n_mels, T) 84 | # normalize 85 | log_mel = log_mel / np.max(np.abs(log_mel)) 86 | 87 | audInp = torch.from_numpy(log_mel) # (80, T) 88 | 89 | 90 | # SpecAugument 91 | # I used https://github.com/zcaceres/spec_augment 92 | if train == True: 93 | audInp = audInp.unsqueeze(0) 94 | audInp = time_mask(freq_mask(audInp, F=27, num_masks=2),T=int(audInp.shape[2]*0.05), num_masks=2) 95 | audInp = audInp.squeeze(0) 96 | 97 | audInp = audInp.transpose(0,1) # (T, 80) 98 | 99 | # input Length: The length after applying Conv2dSubsampling. 100 | inpLen = len(audInp) >> 2 101 | inpLen -= 1 102 | inpLen = torch.tensor(inpLen) 103 | 104 | if targetFile is not None: 105 | trgt = torch.from_numpy(trgt) 106 | trgtLen = torch.tensor(trgtLen) 107 | else: 108 | trgt, trgtLen = None, None 109 | 110 | return audInp, trgt, inpLen, trgtLen 111 | 112 | 113 | def collate_fn(dataBatch): 114 | """ 115 | Collate function definition used in Dataloaders. 116 | """ 117 | 118 | inputBatch = pad_sequence([data[0] for data in dataBatch]) 119 | targetBatch = pad_sequence([data[1] for data in dataBatch]) 120 | 121 | inputLenBatch = torch.stack([data[2] for data in dataBatch]) 122 | if not any(data[3] is None for data in dataBatch): 123 | targetLenBatch = torch.stack([data[3] for data in dataBatch]) 124 | else: 125 | targetLenBatch = None 126 | 127 | return inputBatch, targetBatch, inputLenBatch, targetLenBatch -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # https://github.com/sooftware/conformer/blob/main/conformer/encoder.py 4 | from conformer.encoder import ConformerBlock 5 | from conformer.modules import Linear 6 | import torch.nn.functional as F 7 | import os, sys 8 | 9 | 10 | from torchinfo import summary 11 | 12 | import math 13 | 14 | class PositionalEncoding(nn.Module): 15 | 16 | """ 17 | A layer to add positional encodings to the inputs of a Transformer model. 18 | Formula: 19 | PE(pos,2i) = sin(pos/10000^(2i/d_model)) 20 | PE(pos,2i+1) = cos(pos/10000^(2i/d_model)) 21 | """ 22 | 23 | def __init__(self, dModel, maxLen): 24 | super(PositionalEncoding, self).__init__() 25 | pe = torch.zeros(maxLen, dModel) 26 | position = torch.arange(0, maxLen, dtype=torch.float).unsqueeze(dim=-1) 27 | denominator = torch.exp(torch.arange(0, dModel, 2).float()*(math.log(10000.0)/dModel)) 28 | pe[:, 0::2] = torch.sin(position/denominator) 29 | pe[:, 1::2] = torch.cos(position/denominator) 30 | pe = pe.unsqueeze(dim=0).transpose(0, 1) 31 | self.register_buffer("pe", pe) 32 | 33 | 34 | def forward(self, inputBatch): 35 | outputBatch = inputBatch + self.pe[:inputBatch.shape[0],:,:] 36 | return outputBatch 37 | 38 | 39 | # -- auxiliary functions 40 | def threeD_to_2D_tensor(x): 41 | n_batch, n_channels, s_time, sx, sy = x.shape 42 | x = x.transpose(1, 2) 43 | return x.reshape(n_batch*s_time, n_channels, sx, sy) 44 | 45 | 46 | 47 | 48 | class AVSRwithConf2(nn.Module): #with log_mel 49 | def __init__(self, numClasses): 50 | super(AVSRwithConf2, self).__init__() 51 | self.e = 17 52 | self.d_k = 256 53 | self.logmel_dim = 80 54 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | self.pos_encode = PositionalEncoding(dModel=self.d_k, maxLen=1000) 56 | self.sequential = nn.Sequential( 57 | nn.Conv2d(in_channels=1, out_channels=self.d_k, kernel_size=3, stride=2), 58 | nn.ReLU(), 59 | nn.Conv2d(self.d_k, self.d_k, kernel_size=3, stride=2), 60 | nn.ReLU(), 61 | ) 62 | self.input_projection = nn.Sequential( 63 | Linear(self.d_k * (((self.logmel_dim - 1) // 2 - 1) // 2), self.d_k), 64 | self.pos_encode, 65 | nn.Dropout(p=0.1), 66 | ) 67 | self.linear1 = nn.Linear(512, self.d_k) 68 | self.layers_A = nn.ModuleList([ConformerBlock(encoder_dim=self.d_k).to(self.device) for _ in range(self.e)]) 69 | self.linear3 = nn.Linear(256, 4 * self.d_k) 70 | self.linear4 = nn.Linear(1024, self.d_k) 71 | self.bn1 = nn.BatchNorm1d(num_features=4 * self.d_k) 72 | self.embeddings = nn.Embedding(numClasses, self.d_k) 73 | self.relu = nn.ReLU() 74 | self.decoder_layer = nn.TransformerDecoderLayer(d_model=self.d_k, nhead=8) 75 | self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=6) 76 | self.linear5 = nn.Linear(256, numClasses-1) 77 | self.linear6 = nn.Linear(256, numClasses-1) 78 | 79 | 80 | def forward(self, x, tgt, tgt_mask, tgt_padding_mask): 81 | # x (Batch, T, 80) 82 | # tgt (Batch, S) 83 | # tgtmask (S, S) 84 | # tgt_paddingmask (Batch, S) 85 | 86 | # Conv2Dsubsampling https://github.com/sooftware/conformer/blob/main/conformer/convolution.py 87 | x = self.sequential(x.unsqueeze(1)) 88 | batch_size, channels, subsampled_lengths, sumsampled_dim = x.size() 89 | x = x.permute(0, 2, 1, 3) 90 | x = x.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim) 91 | x = self.input_projection(x) 92 | 93 | 94 | #tgt embedding module 95 | tgt = self.embeddings(tgt) #tgt=(B, T), embedding=(B,T,D) 96 | tgt = tgt.transpose(0,1) #(T,B,D) 97 | tgt = self.pos_encode(tgt) 98 | 99 | #conformer block 100 | for layer in self.layers_A: 101 | x = layer(x) 102 | 103 | # MLP 104 | x = self.linear3(x) # N,T,C 105 | x = x.transpose(1,2) 106 | x = self.bn1(x) # N,C,T 107 | x = x.transpose(1,2) 108 | x = self.relu(x) 109 | x = self.linear4(x) 110 | 111 | # to CE 112 | to_CE = x.transpose(0,1) 113 | to_CE = self.decoder(tgt=tgt, memory=to_CE, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask) 114 | to_CE = self.linear5(to_CE) 115 | 116 | # to CTC 117 | x = self.linear6(x) 118 | x = F.log_softmax(x, dim=2) 119 | 120 | to_CE = to_CE.transpose(0,1) 121 | return to_CE, x #(B,T=len(tgt), C),(B, T=len(frame), C) 122 | 123 | def encode(self, x): 124 | 125 | # logmel subsampling, linear, dropout 126 | x = self.sequential(x.unsqueeze(1)) 127 | batch_size, channels, subsampled_lengths, sumsampled_dim = x.size() 128 | x = x.permute(0, 2, 1, 3) #(N, T, encoderdim, 19) 19はdimの80が畳みこまれた結果。固定値 129 | x = x.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim) #(N, T/4, encoderdim x Dim) Dim=80 130 | x = self.input_projection(x) #(N, T/4, encoderdim) 131 | 132 | # backend embedding module 133 | x = self.linear1(x) 134 | x = self.pos_encode(x) 135 | 136 | 137 | #conformer block 138 | for layer in self.layers_A: 139 | x = layer(x) 140 | 141 | # concat and MLP 142 | x = torch.cat([x, x], 2) 143 | x = self.linear3(x) # N,T,C 144 | x = x.transpose(1,2) 145 | x = self.bn1(x) # N,C,T 146 | x = x.transpose(1,2) 147 | x = self.relu(x) 148 | x = self.linear4(x) 149 | 150 | return x 151 | 152 | 153 | class CustomLoss(nn.Module): 154 | def __init__(self, ramda, beta): 155 | super().__init__() 156 | self.ramda = ramda 157 | self.bata = beta 158 | self.loss_ctc = nn.CTCLoss(zero_infinity=True) 159 | self.loss_ce = nn.CrossEntropyLoss(ignore_index=0) 160 | 161 | def forward(self, CE_x, CTC_x, tgt, in_len, tgt_len): 162 | tgt = tgt.to(torch.long) 163 | CTC_x = CTC_x.transpose(0,1) #(T, B, C) 164 | loss1 = self.ramda*(self.loss_ctc(CTC_x, tgt, in_len, tgt_len)) 165 | CE_x = CE_x.permute(0,2,1) #(B, C, T) 166 | loss2 = (1-self.ramda)*(self.loss_ce(CE_x, tgt)) 167 | loss = loss1 + loss2 168 | 169 | return loss, loss1, loss2 170 | 171 | 172 | if __name__ == '__main__': 173 | 174 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 175 | 176 | M2 = AVSRwithConf2(41) 177 | 178 | # avsrconf2 summary 179 | print(summary(M2, input_size=[(2, 160, 80), (2, 30), (30, 30), (2, 30)], dtypes=[torch.float, torch.long, torch.bool, torch.bool])) 180 | 181 | -------------------------------------------------------------------------------- /specaugment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def time_warp(spec, W=50): 5 | num_rows = spec.shape[2] 6 | spec_len = spec.shape[1] 7 | device = spec.device 8 | 9 | # adapted from https://github.com/DemisEom/SpecAugment/ 10 | pt = (num_rows - 2* W) * torch.rand([1], dtype=torch.float) + W # random point along the time axis 11 | src_ctr_pt_freq = torch.arange(0, spec_len // 2) # control points on freq-axis 12 | src_ctr_pt_time = torch.ones_like(src_ctr_pt_freq) * pt # control points on time-axis 13 | src_ctr_pts = torch.stack((src_ctr_pt_freq, src_ctr_pt_time), dim=-1) 14 | src_ctr_pts = src_ctr_pts.float().to(device) 15 | 16 | # Destination 17 | w = 2 * W * torch.rand([1], dtype=torch.float) - W# distance 18 | dest_ctr_pt_freq = src_ctr_pt_freq 19 | dest_ctr_pt_time = src_ctr_pt_time + w 20 | dest_ctr_pts = torch.stack((dest_ctr_pt_freq, dest_ctr_pt_time), dim=-1) 21 | dest_ctr_pts = dest_ctr_pts.float().to(device) 22 | 23 | # warp 24 | source_control_point_locations = torch.unsqueeze(src_ctr_pts, 0) # (1, v//2, 2) 25 | dest_control_point_locations = torch.unsqueeze(dest_ctr_pts, 0) # (1, v//2, 2) 26 | warped_spectro, dense_flows = sparse_image_warp(spec, source_control_point_locations, dest_control_point_locations) 27 | return warped_spectro.squeeze(3) 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | ################################################# 36 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ### 37 | ################################################# 38 | # file to edit: dev_nb/SparseImageWarp.ipynb 39 | 40 | import torch 41 | 42 | def sparse_image_warp(img_tensor, 43 | source_control_point_locations, 44 | dest_control_point_locations, 45 | interpolation_order=2, 46 | regularization_weight=0.0, 47 | num_boundaries_points=0): 48 | device = img_tensor.device 49 | control_point_flows = (dest_control_point_locations - source_control_point_locations) 50 | 51 | # clamp_boundaries = num_boundary_points > 0 52 | # boundary_points_per_edge = num_boundary_points - 1 53 | batch_size, image_height, image_width = img_tensor.shape 54 | flattened_grid_locations = get_flat_grid_locations(image_height, image_width, device) 55 | 56 | # IGNORED FOR OUR BASIC VERSION... 57 | # flattened_grid_locations = constant_op.constant( 58 | # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) 59 | 60 | # if clamp_boundaries: 61 | # (dest_control_point_locations, 62 | # control_point_flows) = _add_zero_flow_controls_at_boundary( 63 | # dest_control_point_locations, control_point_flows, image_height, 64 | # image_width, boundary_points_per_edge) 65 | 66 | flattened_flows = interpolate_spline( 67 | dest_control_point_locations, 68 | control_point_flows, 69 | flattened_grid_locations, 70 | interpolation_order, 71 | regularization_weight) 72 | 73 | dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width) 74 | 75 | warped_image = dense_image_warp(img_tensor, dense_flows) 76 | 77 | return warped_image, dense_flows 78 | 79 | def get_grid_locations(image_height, image_width, device): 80 | y_range = torch.linspace(0, image_height - 1, image_height, device=device) 81 | x_range = torch.linspace(0, image_width - 1, image_width, device=device) 82 | y_grid, x_grid = torch.meshgrid(y_range, x_range) 83 | return torch.stack((y_grid, x_grid), -1) 84 | 85 | def flatten_grid_locations(grid_locations, image_height, image_width): 86 | return torch.reshape(grid_locations, [image_height * image_width, 2]) 87 | 88 | def get_flat_grid_locations(image_height, image_width, device): 89 | y_range = torch.linspace(0, image_height - 1, image_height, device=device) 90 | x_range = torch.linspace(0, image_width - 1, image_width, device=device) 91 | y_grid, x_grid = torch.meshgrid(y_range, x_range) 92 | return torch.stack((y_grid, x_grid), -1).reshape([image_height * image_width, 2]) 93 | 94 | def create_dense_flows(flattened_flows, batch_size, image_height, image_width): 95 | # possibly .view 96 | return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2]) 97 | 98 | def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0,): 99 | # First, fit the spline to the observed data. 100 | w, v = solve_interpolation(train_points, train_values, order, regularization_weight) 101 | # Then, evaluate the spline at the query locations. 102 | query_values = apply_interpolation(query_points, train_points, w, v, order) 103 | 104 | return query_values 105 | 106 | def solve_interpolation(train_points, train_values, order, regularization_weight, eps=1e-7): 107 | device = train_points.device 108 | b, n, d = train_points.shape 109 | k = train_values.shape[-1] 110 | 111 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.) 112 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. 113 | # To account for python style guidelines we use 114 | # matrix_a for A and matrix_b for B. 115 | 116 | c = train_points 117 | f = train_values.float() 118 | 119 | matrix_a = phi(cross_squared_distance_matrix(c,c), order).unsqueeze(0) # [b, n, n] 120 | # if regularization_weight > 0: 121 | # batch_identity_matrix = array_ops.expand_dims( 122 | # linalg_ops.eye(n, dtype=c.dtype), 0) 123 | # matrix_a += regularization_weight * batch_identity_matrix 124 | 125 | # Append ones to the feature values for the bias term in the linear model. 126 | ones = torch.ones(n, dtype=train_points.dtype, device=device).view([-1, n, 1]) 127 | matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1] 128 | 129 | # [b, n + d + 1, n] 130 | left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1) 131 | 132 | num_b_cols = matrix_b.shape[2] # d + 1 133 | 134 | # In Tensorflow, zeros are used here. Pytorch solve fails with zeros for some reason we don't understand. 135 | # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication. 136 | lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) *eps 137 | right_block = torch.cat((matrix_b, lhs_zeros), 138 | 1) # [b, n + d + 1, d + 1] 139 | lhs = torch.cat((left_block, right_block), 140 | 2) # [b, n + d + 1, n + d + 1] 141 | 142 | rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype, device=device).float() 143 | rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k] 144 | 145 | # Then, solve the linear system and unpack the results. 146 | X, LU = torch.solve(rhs, lhs) 147 | w = X[:, :n, :] 148 | v = X[:, n:, :] 149 | return w, v 150 | 151 | def cross_squared_distance_matrix(x, y): 152 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim). 153 | Computes the pairwise distances between rows of x and rows of y 154 | Args: 155 | x: [batch_size, n, d] float `Tensor` 156 | y: [batch_size, m, d] float `Tensor` 157 | Returns: 158 | squared_dists: [batch_size, n, m] float `Tensor`, where 159 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 160 | """ 161 | x_norm_squared = torch.sum(torch.mul(x, x)) 162 | y_norm_squared = torch.sum(torch.mul(y, y)) 163 | 164 | x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0,1)) 165 | 166 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj 167 | squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared 168 | 169 | return squared_dists.float() 170 | 171 | def phi(r, order): 172 | """Coordinate-wise nonlinearity used to define the order of the interpolation. 173 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. 174 | Args: 175 | r: input op 176 | order: interpolation order 177 | Returns: 178 | phi_k evaluated coordinate-wise on r, for k = r 179 | """ 180 | EPSILON=torch.tensor(1e-10, device=r.device) 181 | # using EPSILON prevents log(0), sqrt0), etc. 182 | # sqrt(0) is well-defined, but its gradient is not 183 | if order == 1: 184 | r = torch.max(r, EPSILON) 185 | r = torch.sqrt(r) 186 | return r 187 | elif order == 2: 188 | return 0.5 * r * torch.log(torch.max(r, EPSILON)) 189 | elif order == 4: 190 | return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON)) 191 | elif order % 2 == 0: 192 | r = torch.max(r, EPSILON) 193 | return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r) 194 | else: 195 | r = torch.max(r, EPSILON) 196 | return torch.pow(r, 0.5 * order) 197 | 198 | def apply_interpolation(query_points, train_points, w, v, order): 199 | """Apply polyharmonic interpolation model to data. 200 | Given coefficients w and v for the interpolation model, we evaluate 201 | interpolated function values at query_points. 202 | Args: 203 | query_points: `[b, m, d]` x values to evaluate the interpolation at 204 | train_points: `[b, n, d]` x values that act as the interpolation centers 205 | ( the c variables in the wikipedia article) 206 | w: `[b, n, k]` weights on each interpolation center 207 | v: `[b, d, k]` weights on each input dimension 208 | order: order of the interpolation 209 | Returns: 210 | Polyharmonic interpolation evaluated at points defined in query_points. 211 | """ 212 | query_points = query_points.unsqueeze(0) 213 | # First, compute the contribution from the rbf term. 214 | pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float()) 215 | phi_pairwise_dists = phi(pairwise_dists, order) 216 | 217 | rbf_term = torch.matmul(phi_pairwise_dists, w) 218 | 219 | # Then, compute the contribution from the linear term. 220 | # Pad query_points with ones, for the bias term in the linear model. 221 | ones = torch.ones_like(query_points[..., :1]) 222 | query_points_pad = torch.cat(( 223 | query_points, 224 | ones 225 | ), 2).float() 226 | linear_term = torch.matmul(query_points_pad, v) 227 | 228 | return rbf_term + linear_term 229 | 230 | 231 | def dense_image_warp(image, flow): 232 | """Image warping using per-pixel flow vectors. 233 | Apply a non-linear warp to the image, where the warp is specified by a dense 234 | flow field of offset vectors that define the correspondences of pixel values 235 | in the output image back to locations in the source image. Specifically, the 236 | pixel value at output[b, j, i, c] is 237 | images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. 238 | The locations specified by this formula do not necessarily map to an int 239 | index. Therefore, the pixel value is obtained by bilinear 240 | interpolation of the 4 nearest pixels around 241 | (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside 242 | of the image, we use the nearest pixel values at the image boundary. 243 | Args: 244 | image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. 245 | flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. 246 | name: A name for the operation (optional). 247 | Note that image and flow can be of type tf.half, tf.float32, or tf.float64, 248 | and do not necessarily have to be the same type. 249 | Returns: 250 | A 4-D float `Tensor` with shape`[batch, height, width, channels]` 251 | and same type as input image. 252 | Raises: 253 | ValueError: if height < 2 or width < 2 or the inputs have the wrong number 254 | of dimensions. 255 | """ 256 | image = image.unsqueeze(3) # add a single channel dimension to image tensor 257 | batch_size, height, width, channels = image.shape 258 | device = image.device 259 | 260 | # The flow is defined on the image grid. Turn the flow into a list of query 261 | # points in the grid space. 262 | grid_x, grid_y = torch.meshgrid( 263 | torch.arange(width, device=device), torch.arange(height, device=device)) 264 | 265 | stacked_grid = torch.stack((grid_y, grid_x), dim=2).float() 266 | 267 | batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2) 268 | 269 | query_points_on_grid = batched_grid - flow 270 | query_points_flattened = torch.reshape(query_points_on_grid, 271 | [batch_size, height * width, 2]) 272 | # Compute values at the query points, then reshape the result back to the 273 | # image grid. 274 | interpolated = interpolate_bilinear(image, query_points_flattened) 275 | interpolated = torch.reshape(interpolated, 276 | [batch_size, height, width, channels]) 277 | return interpolated 278 | 279 | def interpolate_bilinear(grid, 280 | query_points, 281 | name='interpolate_bilinear', 282 | indexing='ij'): 283 | """Similar to Matlab's interp2 function. 284 | Finds values for query points on a grid using bilinear interpolation. 285 | Args: 286 | grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. 287 | query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. 288 | name: a name for the operation (optional). 289 | indexing: whether the query points are specified as row and column (ij), 290 | or Cartesian coordinates (xy). 291 | Returns: 292 | values: a 3-D `Tensor` with shape `[batch, N, channels]` 293 | Raises: 294 | ValueError: if the indexing mode is invalid, or if the shape of the inputs 295 | invalid. 296 | """ 297 | if indexing != 'ij' and indexing != 'xy': 298 | raise ValueError('Indexing mode must be \'ij\' or \'xy\'') 299 | 300 | 301 | shape = grid.shape 302 | if len(shape) != 4: 303 | msg = 'Grid must be 4 dimensional. Received size: ' 304 | raise ValueError(msg + str(grid.shape)) 305 | 306 | batch_size, height, width, channels = grid.shape 307 | 308 | shape = [batch_size, height, width, channels] 309 | query_type = query_points.dtype 310 | grid_type = grid.dtype 311 | grid_device = grid.device 312 | 313 | num_queries = query_points.shape[1] 314 | 315 | alphas = [] 316 | floors = [] 317 | ceils = [] 318 | index_order = [0, 1] if indexing == 'ij' else [1, 0] 319 | unstacked_query_points = query_points.unbind(2) 320 | 321 | for dim in index_order: 322 | queries = unstacked_query_points[dim] 323 | 324 | size_in_indexing_dimension = shape[dim + 1] 325 | 326 | # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 327 | # is still a valid index into the grid. 328 | max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type, device=grid_device) 329 | min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device) 330 | maxx = torch.max(min_floor, torch.floor(queries)) 331 | floor = torch.min(maxx, max_floor) 332 | int_floor = floor.long() 333 | floors.append(int_floor) 334 | ceil = int_floor + 1 335 | ceils.append(ceil) 336 | 337 | # alpha has the same type as the grid, as we will directly use alpha 338 | # when taking linear combinations of pixel values from the image. 339 | 340 | 341 | alpha = (queries - floor).clone().detach().type(grid_type) 342 | min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device) 343 | max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device) 344 | alpha = torch.min(torch.max(min_alpha, alpha), max_alpha) 345 | 346 | # Expand alpha to [b, n, 1] so we can use broadcasting 347 | # (since the alpha values don't depend on the channel). 348 | alpha = torch.unsqueeze(alpha, 2) 349 | alphas.append(alpha) 350 | 351 | flattened_grid = torch.reshape( 352 | grid, [batch_size * height * width, channels]) 353 | batch_offsets = torch.reshape( 354 | torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1]) 355 | 356 | # This wraps array_ops.gather. We reshape the image data such that the 357 | # batch, y, and x coordinates are pulled into the first dimension. 358 | # Then we gather. Finally, we reshape the output back. It's possible this 359 | # code would be made simpler by using array_ops.gather_nd. 360 | def gather(y_coords, x_coords, name): 361 | linear_coordinates = batch_offsets + y_coords * width + x_coords 362 | gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates) 363 | return torch.reshape(gathered_values, 364 | [batch_size, num_queries, channels]) 365 | 366 | # grab the pixel values in the 4 corners around each query point 367 | top_left = gather(floors[0], floors[1], 'top_left') 368 | top_right = gather(floors[0], ceils[1], 'top_right') 369 | bottom_left = gather(ceils[0], floors[1], 'bottom_left') 370 | bottom_right = gather(ceils[0], ceils[1], 'bottom_right') 371 | 372 | interp_top = alphas[1] * (top_right - top_left) + top_left 373 | interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left 374 | interp = alphas[0] * (interp_bottom - interp_top) + interp_top 375 | 376 | return interp 377 | 378 | 379 | #Export 380 | def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False): 381 | cloned = spec.clone() 382 | num_mel_channels = cloned.shape[1] 383 | 384 | for i in range(0, num_masks): 385 | f = random.randrange(0, F) 386 | f_zero = random.randrange(0, num_mel_channels - f) 387 | 388 | # avoids randrange error if values are equal and range is empty 389 | if (f_zero == f_zero + f): return cloned 390 | 391 | mask_end = random.randrange(f_zero, f_zero + f) 392 | if (replace_with_zero): cloned[0][f_zero:mask_end] = 0 393 | else: cloned[0][f_zero:mask_end] = cloned.mean() 394 | 395 | return cloned 396 | 397 | 398 | 399 | #Export 400 | def time_mask(spec, T=40, num_masks=1, replace_with_zero=False): 401 | cloned = spec.clone() 402 | len_spectro = cloned.shape[2] 403 | 404 | for i in range(0, num_masks): 405 | t = random.randrange(0, T) 406 | t_zero = random.randrange(0, len_spectro - t) 407 | 408 | # avoids randrange error if values are equal and range is empty 409 | if (t_zero == t_zero + t): return cloned 410 | 411 | mask_end = random.randrange(t_zero, t_zero + t) 412 | if (replace_with_zero): cloned[0][:,t_zero:mask_end] = 0 413 | else: cloned[0][:,t_zero:mask_end] = cloned.mean() 414 | return cloned -------------------------------------------------------------------------------- /train_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._C import device 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import os, shutil 10 | from dataset import MyDataset 11 | from dataset import collate_fn 12 | from model import AVSRwithConf2, CustomLoss 13 | from tqdm import tqdm 14 | import sys 15 | from collections import OrderedDict 16 | 17 | 18 | 19 | #character to index mapping 20 | CHAR_TO_INDEX = {" ":1, "'":22, "1":30, "0":29, "3":37, "2":32, "5":34, "4":38, "7":36, "6":35, "9":31, "8":33, 21 | "A":5, "C":17, "B":20, "E":2, "D":12, "G":16, "F":19, "I":6, "H":9, "K":24, "J":25, "M":18, 22 | "L":11, "O":4, "N":7, "Q":27, "P":21, "S":8, "R":10, "U":13, "T":3, "W":15, "V":23, "Y":14, 23 | "X":26, "Z":28, "":39, "":40} 24 | 25 | #index to character reverse mapping 26 | INDEX_TO_CHAR = {1:" ", 22:"'", 30:"1", 29:"0", 37:"3", 32:"2", 34:"5", 38:"4", 36:"7", 35:"6", 31:"9", 33:"8", 27 | 5:"A", 17:"C", 20:"B", 2:"E", 12:"D", 16:"G", 19:"F", 6:"I", 9:"H", 24:"K", 25:"J", 18:"M", 28 | 11:"L", 4:"O", 7:"N", 27:"Q", 21:"P", 8:"S", 10:"R", 13:"U", 3:"T", 15:"W", 23:"V", 14:"Y", 29 | 26:"X", 28:"Z", 39:"", 40:""} 30 | # zero padding 31 | PAD_IDX = 0 32 | # the number of characters = CHAR_TO_INDEX + zero(padding) 33 | NUM_CLASS = 41 34 | #absolute path to the data directory 35 | DATA_DIRECTORY = 'C:/Users/test/Desktop/python/datasets/mvlrs_v1' 36 | # batch 37 | BATCH_SIZE = 2 38 | # epoch 39 | NUM_STEPS = 50 40 | 41 | 42 | def generate_square_subsequent_mask(sz, device): 43 | mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1) 44 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 45 | return mask 46 | 47 | 48 | def create_mask(tgt, device): 49 | tgt_seq_len = tgt.shape[0] 50 | tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device) 51 | tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1) 52 | return tgt_mask, tgt_padding_mask #(S, S), (Batch, S) 53 | 54 | 55 | def train(model, trainLoader, optimizer, loss_function, device, trainParams): 56 | 57 | trainingLoss = 0 58 | 59 | with tqdm(trainLoader, leave=False, desc="Train", ncols=75) as pbar: 60 | for batch, (inputBatch, targetBatch, inputLenBatch, targetLenBatch) in enumerate(pbar): 61 | 62 | inputBatch_A, targetBatch = (inputBatch.float()).to(device), (targetBatch.long()).to(device) 63 | inputLenBatch, targetLenBatch = (inputLenBatch.int()).to(device), (targetLenBatch.int()).to(device) # (Batch), (Batch) 64 | inputBatch_A = inputBatch_A.transpose(0,1) #(Batch, T, 80) 65 | 66 | tgt_mask, tgt_padding_mask = create_mask(targetBatch[:-1,:], device) #(Batch, S, S), (Batch, S) 67 | 68 | targetBatch = targetBatch.transpose(0,1) #(Batch, S) 69 | 70 | optimizer.zero_grad() 71 | model.train() 72 | output_Att, output_CTC = model(inputBatch_A, targetBatch[:,:-1], tgt_mask, tgt_padding_mask) 73 | 74 | with torch.backends.cudnn.flags(enabled=False): 75 | loss, ctc, ce = loss_function(output_Att, output_CTC, targetBatch[:,1:], inputLenBatch, targetLenBatch-1) 76 | pbar.set_postfix(OrderedDict(CTC=ctc.item(), CE=ce.item())) 77 | loss.backward() 78 | optimizer.step() 79 | 80 | trainingLoss = trainingLoss + loss.item() 81 | 82 | trainingLoss = trainingLoss/len(trainLoader) 83 | return trainingLoss 84 | 85 | 86 | 87 | def evaluate(model, evalLoader, loss_function, device, evalParams): 88 | 89 | evalLoss = 0 90 | 91 | with tqdm(evalLoader, leave=False, desc="Eval", ncols=75) as pbar: 92 | 93 | for batch, (inputBatch, targetBatch, inputLenBatch, targetLenBatch) in enumerate(pbar): 94 | inputBatch_A, targetBatch = (inputBatch.float()).to(device), (targetBatch.long()).to(device) 95 | inputLenBatch, targetLenBatch = (inputLenBatch.long()).to(device), (targetLenBatch.long()).to(device) # (Batch,), (Batch,) 96 | inputBatch_A = inputBatch_A.transpose(0,1) #(Batch, T, 80) 97 | tgt_mask, tgt_padding_mask = create_mask(targetBatch[:-1,:], device) #(S, S), (Batch, S) 98 | targetBatch = targetBatch.transpose(0,1) #(Batch, S) 99 | 100 | model.eval() 101 | with torch.no_grad(): 102 | output_Att, output_CTC = model(inputBatch_A, targetBatch[:,:-1], tgt_mask, tgt_padding_mask) 103 | with torch.backends.cudnn.flags(enabled=False): 104 | loss, ctc, ce = loss_function(output_Att, output_CTC, targetBatch[:,1:], inputLenBatch, targetLenBatch-1) 105 | pbar.set_postfix(OrderedDict(CTC=ctc.item(), CE=ce.item())) 106 | 107 | evalLoss = evalLoss + loss.item() 108 | 109 | evalLoss = evalLoss/len(evalLoader) 110 | return evalLoss 111 | 112 | 113 | def main(): 114 | 115 | matplotlib.use("Agg") 116 | 117 | #seed for random number generators 118 | np.random.seed(19220297) 119 | torch.manual_seed(19220297) 120 | 121 | #use gpu 122 | gpuAvailable = torch.cuda.is_available() 123 | device = torch.device("cuda" if gpuAvailable else "cpu") 124 | kwargs = {"num_workers": 6, "pin_memory": True} if gpuAvailable else {} 125 | torch.backends.cudnn.deterministic = True 126 | torch.backends.cudnn.benchmark = False 127 | 128 | # declaring the pretrain and the preval datasets and the corresponding dataloaders 129 | # Window = window to use while computing log-mel spectrogram 130 | # WinLen = window size 131 | # Shift = hop length 132 | # Dim = log-mel feature's dimention 133 | audioParams = {"Window":"hann", "WinLen":512, "Shift":160, "Dim":80} 134 | trainData = MyDataset(["pretrain","train"], DATA_DIRECTORY, CHAR_TO_INDEX, audioParams) 135 | valData = MyDataset(["val"], DATA_DIRECTORY, CHAR_TO_INDEX, audioParams, train=False) 136 | 137 | trainLoader = DataLoader(trainData, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True, **kwargs) 138 | valLoader = DataLoader(valData, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True, **kwargs) 139 | 140 | #declaring the model 141 | model = AVSRwithConf2(NUM_CLASS) 142 | model.to(device) 143 | 144 | optimizer = optim.Adam(model.parameters(), lr=0.0004, betas=(0.9, 0.98), eps=1e-9) 145 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, 146 | patience=5, threshold=0.001, 147 | threshold_mode="abs", min_lr=1e-6, verbose=True) 148 | 149 | # CTC attention loss: ramda*CTC + (1-ramda)*Attention 150 | loss_function = CustomLoss(ramda=0.1, beta=0.6) 151 | 152 | # loss curve 153 | trainingLossCurve = list() 154 | validationLossCurve = list() 155 | 156 | # Train 157 | print("/nTraining the model .... /n") 158 | 159 | trainParams = {"spaceIx":CHAR_TO_INDEX[" "], "eosIx":CHAR_TO_INDEX[""]} 160 | valParams = {"decodeScheme":"greedy", "spaceIx":CHAR_TO_INDEX[" "], "eosIx":CHAR_TO_INDEX[""]} 161 | 162 | for step in range(NUM_STEPS): 163 | 164 | #train the model for one step 165 | trainingLoss= train(model, trainLoader, optimizer, loss_function, device, trainParams) 166 | trainingLossCurve.append(trainingLoss) 167 | 168 | #evaluate the model on validation set 169 | validationLoss= evaluate(model, valLoader, loss_function, device, valParams) 170 | validationLossCurve.append(validationLoss) 171 | 172 | #printing the stats after each step 173 | print("Step: %03d || Tr.Loss: %.6f Val.Loss: %.6f" 174 | %(step, trainingLoss, validationLoss)) 175 | 176 | #make a scheduler step 177 | scheduler.step(validationLoss) 178 | 179 | #saving the model weights and loss/metric curves in the checkpoints directory 180 | savePath = "./checkpoints/train-step_{:04d}.pt".format(step) 181 | torch.save(model.state_dict(), savePath) 182 | 183 | plt.figure() 184 | plt.title("Loss Curves") 185 | plt.xlabel("Step No.") 186 | plt.ylabel("Loss value") 187 | plt.plot(list(range(1, len(trainingLossCurve)+1)), trainingLossCurve, "blue", label="Train") 188 | plt.plot(list(range(1, len(validationLossCurve)+1)), validationLossCurve, "red", label="Validation") 189 | plt.legend() 190 | plt.savefig("./checkpoints/train-step_{:04d}-loss.png".format(step)) 191 | plt.close() 192 | 193 | print("/nTraining Done./n") 194 | return 195 | 196 | 197 | if __name__ == "__main__": 198 | main() --------------------------------------------------------------------------------