├── .gitignore
├── infer.py
├── model.py
└── train-rnn.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pt
3 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | threshold = 0.07 # You can use the threshold value you settled on in the notebook, or adjust until it's right
2 | num_layers = 12
3 | model_path = f"model-{num_layers}.pt"
4 | do_ascii_bell = True # If you're piping desktop audio, this might create a feedback loop
5 | 
6 | import sys
7 | import kaldi_native_fbank as knf
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from model import StreamingTurnModel
12 | 
13 | print("""
14 | This script reads PCM16 from stdin, so make sure to pipe something or else nothing will happen.
15 | 
16 | Run this with desktop audio:
17 | $ parec --format=s16 --rate=16000 --channels=1 --latency-ms=100 --device=@DEFAULT_MONITOR@ | python3 infer.py
18 | 
19 | Run this with mic audio:
20 | $ parec --format=s16 --rate=16000 --channels=1 --latency-ms=100 | python3 infer.py
21 | 
22 | 
23 | Download a sample model from
24 | https://sapples.net/ckpt/turn-detection/model-12.pt
25 | """)
26 | 
27 | model = StreamingTurnModel(None, num_layers=num_layers)
28 | print(model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)))
29 | 
30 | model = model.eval()
31 | 
32 | def predict_endpoint(model, frames, states):
33 |     with torch.no_grad():
34 |         frames = torch.tensor(frames).unsqueeze(0)
35 |         frame_lengths = torch.tensor([9]).int()
36 |         result, result_lens, states = model(frames, frame_lengths, states=states)
37 |         result = F.sigmoid(result)
38 |     return result.item(), states
39 | 
40 | opts = knf.FbankOptions()
41 | opts.frame_opts.samp_freq = 16000
42 | opts.frame_opts.frame_shift_ms = 10.0
43 | opts.frame_opts.frame_length_ms = 25.0
44 | opts.frame_opts.snip_edges = True
45 | opts.mel_opts.num_bins = 80
46 | 
47 | fbank = knf.OnlineFbank(opts)
48 | 
49 | 
50 | fbank_head = 0
51 | states = (
52 |     torch.zeros(model.encoder.num_encoder_layers, 1, model.encoder.d_model),
53 |     torch.zeros(model.encoder.num_encoder_layers, 1, model.encoder.rnn_hidden_size)
54 | )
55 | 
56 | last_belled = False
57 | 
58 | while True:
59 |     data = sys.stdin.buffer.read(2 * 400)  # 400 samples = 25 ms at 16 kHz
60 |     data = np.frombuffer(data, dtype=np.int16).astype(np.float32)
61 |     data = data / 32768.0
62 |     fbank.accept_waveform(16_000, data.tolist())
63 | 
64 |     if (fbank.num_frames_ready - fbank_head) >= 9:
65 |         frames = [fbank.get_frame(fbank_head + i).tolist() for i in range(9)]
66 |         fbank_head += 4  # hop by 4 frames (40 ms); the encoder subsamples by 4x
67 |         is_endpoint, states = predict_endpoint(model, frames, states)
68 |         print(is_endpoint)
69 |         if is_endpoint > threshold:
70 |             if not last_belled:
71 |                 print("\n\nDing!\n\n")
72 |                 if do_ascii_bell:
73 |                     print("\a")
74 |             last_belled = True
75 |         else:
76 |             last_belled = False
77 |     else:
78 |         continue
79 | 
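# Note on the frame bookkeeping above: the encoder subsamples by 4x, and its
# Conv2dSubsampling maps T input frames to ((T - 3) // 2 - 1) // 2 output
# frames, so each 9-frame window yields exactly one prediction, and the
# 4-frame hop (40 ms at a 10 ms frame shift) keeps the LSTM states aligned
# with real time. A quick, dependency-free check of that arithmetic:
#
#     def subsampled_len(T: int) -> int:
#         return ((T - 3) // 2 - 1) // 2
#
#     assert subsampled_len(9) == 1  # one window -> one prediction
#     assert all(subsampled_len(9 + 4 * k) == 1 + k for k in range(10))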
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # This code is taken from various files around https://github.com/k2-fsa/icefall
2 | 
3 | import copy
4 | import collections
5 | import random
6 | from itertools import repeat
7 | from typing import Optional, Tuple, List
8 | 
9 | import torch
10 | import torch.backends.cudnn.rnn as rnn
11 | import torch.nn as nn
12 | from torch import _VF, Tensor
13 | 
14 | 
15 | # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
16 | # Fixed: https://github.com/pytorch/pytorch/pull/49853
17 | # The fix was included in v1.9.0
18 | # https://github.com/pytorch/pytorch/releases/tag/v1.9.0
19 | def is_jit_tracing():
20 |     if torch.jit.is_scripting():
21 |         return False
22 |     elif torch.jit.is_tracing():
23 |         return True
24 |     return False
25 | 
26 | class EncoderInterface(nn.Module):
27 |     def forward(
28 |         self, x: torch.Tensor, x_lens: torch.Tensor
29 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
30 |         """
31 |         Args:
32 |           x:
33 |             A tensor of shape (batch_size, input_seq_len, num_features)
34 |             containing the input features.
35 |           x_lens:
36 |             A tensor of shape (batch_size,) containing the number of frames
37 |             in `x` before padding.
38 |         Returns:
39 |           Return a tuple containing two tensors:
40 |             - encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
41 |               containing unnormalized probabilities, i.e., the output of a
42 |               linear layer.
43 |             - encoder_out_lens, a tensor of shape (batch_size,) containing
44 |               the number of frames in `encoder_out` before padding.
45 |         """
46 |         raise NotImplementedError("Please implement it in a subclass")
47 | 
48 | 
49 | 
50 | class GradientFilterFunction(torch.autograd.Function):
51 |     @staticmethod
52 |     def forward(
53 |         ctx,
54 |         x: Tensor,
55 |         batch_dim: int,  # e.g., 1
56 |         threshold: float,  # e.g., 10.0
57 |         *params: Tensor,  # module parameters
58 |     ) -> Tuple[Tensor, ...]:
59 |         if x.requires_grad:
60 |             if batch_dim < 0:
61 |                 batch_dim += x.ndim
62 |             ctx.batch_dim = batch_dim
63 |             ctx.threshold = threshold
64 |         return (x,) + params
65 | 
66 |     @staticmethod
67 |     def backward(
68 |         ctx,
69 |         x_grad: Tensor,
70 |         *param_grads: Tensor,
71 |     ) -> Tuple[Tensor, ...]:
72 |         eps = 1.0e-20
73 |         dim = ctx.batch_dim
74 |         norm_dims = [d for d in range(x_grad.ndim) if d != dim]
75 |         norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt()
76 |         median_norm = norm_of_batch.median()
77 | 
78 |         cutoff = median_norm * ctx.threshold
79 |         inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
80 |         mask = 1.0 / (inv_mask + eps)
81 |         x_grad = x_grad * mask
82 | 
83 |         avg_mask = 1.0 / (inv_mask.mean() + eps)
84 |         param_grads = [avg_mask * g for g in param_grads]
85 | 
86 |         return (x_grad, None, None) + tuple(param_grads)
87 | 
88 | 
89 | class GradientFilter(torch.nn.Module):
90 |     """This is used to filter out batch elements that have extremely large
91 |     gradients, by applying soft masks to the input and the module parameters.
92 | 
93 |     Args:
94 |       batch_dim (int):
95 |         The batch dimension.
96 |       threshold (float):
97 |         For each element in batch, its gradient will be
98 |         filtered out if the gradient norm is larger than
99 |         `grad_norm_threshold * median`, where `median` is the median
100 |         value of gradient norms of all elements in batch.
101 |     """
102 | 
103 |     def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
104 |         super(GradientFilter, self).__init__()
105 |         self.batch_dim = batch_dim
106 |         self.threshold = threshold
107 | 
108 |     def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
109 |         if torch.jit.is_scripting() or is_jit_tracing():
110 |             return (x,) + params
111 |         else:
112 |             return GradientFilterFunction.apply(
113 |                 x,
114 |                 self.batch_dim,
115 |                 self.threshold,
116 |                 *params,
117 |             )
118 | 
119 | 
120 | 
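# Example (illustrative, not part of the original icefall code): GradientFilter
# is the identity in the forward pass and only rescales gradients whose
# per-element norm is far above the batch median.
#
#     gf = GradientFilter(batch_dim=0, threshold=10.0)
#     x = torch.randn(8, 4, requires_grad=True)
#     w = nn.Parameter(torch.ones(4))
#     x_f, w_f = gf(x, w)
#     (x_f * w_f).sum().backward()
#     # Rows of x.grad with outlier norms are softly scaled down, and w.grad
#     # is rescaled by the average mask so parameter updates stay consistent.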
121 | class BasicNorm(torch.nn.Module):
122 |     """
123 |     This is intended to be a simpler, and hopefully cheaper, replacement for
124 |     LayerNorm. The observation this is based on is that Transformer-type
125 |     networks, especially with pre-norm, sometimes seem to set one of the
126 |     feature dimensions to a large constant value (e.g. 50), which "defeats"
127 |     the LayerNorm because the output magnitude is then not strongly dependent
128 |     on the other (useful) features. Presumably the weight and bias of the
129 |     LayerNorm are required to allow it to do this.
130 | 
131 |     So the idea is to introduce this large constant value as an explicit
132 |     parameter, that takes the role of the "eps" in LayerNorm, so the network
133 |     doesn't have to do this trick. We make the "eps" learnable.
134 | 
135 |     Args:
136 |        num_channels: the number of channels, e.g. 512.
137 |        channel_dim: the axis/dimension corresponding to the channel,
138 |          interpreted as an offset from the input's ndim if negative.
139 |          This is NOT the num_channels; it should typically be one of
140 |          {-2, -1, 0, 1, 2, 3}.
141 |        eps: the initial "epsilon" that we add as ballast in:
142 |              scale = ((input_vec**2).mean() + epsilon)**-0.5
143 |           Note: our epsilon is actually large, but we keep the name
144 |           to indicate the connection with conventional LayerNorm.
145 |        learn_eps: if true, we learn epsilon; if false, we keep it
146 |           at the initial value.
147 |     """
148 | 
149 |     def __init__(
150 |         self,
151 |         num_channels: int,
152 |         channel_dim: int = -1,  # CAUTION: see documentation.
153 |         eps: float = 0.25,
154 |         learn_eps: bool = True,
155 |     ) -> None:
156 |         super(BasicNorm, self).__init__()
157 |         self.num_channels = num_channels
158 |         self.channel_dim = channel_dim
159 |         if learn_eps:
160 |             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
161 |         else:
162 |             self.register_buffer("eps", torch.tensor(eps).log().detach())
163 | 
164 |     def forward(self, x: Tensor) -> Tensor:
165 |         if not is_jit_tracing():
166 |             assert x.shape[self.channel_dim] == self.num_channels
167 |         scales = (
168 |             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
169 |             + self.eps.exp()
170 |         ) ** -0.5
171 |         return x * scales
172 | 
173 | class ScaledLinear(nn.Linear):
174 |     """
175 |     A modified version of nn.Linear where the parameters are scaled before
176 |     use, via:
177 |          weight = self.weight * self.weight_scale.exp()
178 |          bias = self.bias * self.bias_scale.exp()
179 | 
180 |     Args:
181 |         Accepts the standard args and kwargs that nn.Linear accepts
182 |         e.g. in_features, out_features, bias=False.
183 | 
184 |         initial_scale: you can override this if you want to increase
185 |            or decrease the initial magnitude of the module's output
186 |            (affects the initialization of weight_scale and bias_scale).
187 |            Another option, if you want to do something like this, is
188 |            to re-initialize the parameters.
189 |         initial_speed: this affects how fast the parameter will
190 |            learn near the start of training; you can set it to a
191 |            value less than one if you suspect that a module
192 |            is contributing to instability near the start of training.
193 |            Note: regardless of the use of this option, it's best to
194 |            use schedulers like Noam that have a warm-up period.
195 |            Alternatively you can set it to more than 1 if you want it to
196 |            initially train faster. Must be greater than 0.
197 | """ 198 | 199 | def __init__( 200 | self, 201 | *args, 202 | initial_scale: float = 1.0, 203 | initial_speed: float = 1.0, 204 | **kwargs, 205 | ): 206 | super(ScaledLinear, self).__init__(*args, **kwargs) 207 | initial_scale = torch.tensor(initial_scale).log() 208 | self.weight_scale = nn.Parameter(initial_scale.clone().detach()) 209 | if self.bias is not None: 210 | self.bias_scale = nn.Parameter(initial_scale.clone().detach()) 211 | else: 212 | self.register_parameter("bias_scale", None) 213 | 214 | self._reset_parameters( 215 | initial_speed 216 | ) # Overrides the reset_parameters in nn.Linear 217 | 218 | def _reset_parameters(self, initial_speed: float): 219 | std = 0.1 / initial_speed 220 | a = (3 ** 0.5) * std 221 | nn.init.uniform_(self.weight, -a, a) 222 | if self.bias is not None: 223 | nn.init.constant_(self.bias, 0.0) 224 | fan_in = self.weight.shape[1] * self.weight[0][0].numel() 225 | scale = fan_in ** -0.5 # 1/sqrt(fan_in) 226 | with torch.no_grad(): 227 | self.weight_scale += torch.tensor(scale / std).log() 228 | 229 | def get_weight(self): 230 | return self.weight * self.weight_scale.exp() 231 | 232 | def get_bias(self): 233 | if self.bias is None or self.bias_scale is None: 234 | return None 235 | else: 236 | return self.bias * self.bias_scale.exp() 237 | 238 | def forward(self, input: Tensor) -> Tensor: 239 | return torch.nn.functional.linear( 240 | input, self.get_weight(), self.get_bias() 241 | ) 242 | 243 | 244 | class ScaledConv2d(nn.Conv2d): 245 | # See docs for ScaledLinear 246 | def __init__( 247 | self, 248 | *args, 249 | initial_scale: float = 1.0, 250 | initial_speed: float = 1.0, 251 | **kwargs, 252 | ): 253 | super(ScaledConv2d, self).__init__(*args, **kwargs) 254 | initial_scale = torch.tensor(initial_scale).log() 255 | self.weight_scale = nn.Parameter(initial_scale.clone().detach()) 256 | if self.bias is not None: 257 | self.bias_scale = nn.Parameter(initial_scale.clone().detach()) 258 | else: 259 | self.register_parameter("bias_scale", None) 260 | self._reset_parameters( 261 | initial_speed 262 | ) # Overrides the reset_parameters in base class 263 | 264 | def _reset_parameters(self, initial_speed: float): 265 | std = 0.1 / initial_speed 266 | a = (3 ** 0.5) * std 267 | nn.init.uniform_(self.weight, -a, a) 268 | if self.bias is not None: 269 | nn.init.constant_(self.bias, 0.0) 270 | fan_in = self.weight.shape[1] * self.weight[0][0].numel() 271 | scale = fan_in ** -0.5 # 1/sqrt(fan_in) 272 | with torch.no_grad(): 273 | self.weight_scale += torch.tensor(scale / std).log() 274 | 275 | def get_weight(self): 276 | return self.weight * self.weight_scale.exp() 277 | 278 | def get_bias(self): 279 | # see https://github.com/pytorch/pytorch/issues/24135 280 | bias = self.bias 281 | bias_scale = self.bias_scale 282 | if bias is None or bias_scale is None: 283 | return None 284 | else: 285 | return bias * bias_scale.exp() 286 | 287 | def _conv_forward(self, input, weight): 288 | F = torch.nn.functional 289 | if self.padding_mode != "zeros": 290 | return F.conv2d( 291 | F.pad( 292 | input, 293 | self._reversed_padding_repeated_twice, 294 | mode=self.padding_mode, 295 | ), 296 | weight, 297 | self.get_bias(), 298 | self.stride, 299 | (0, 0), 300 | self.dilation, 301 | self.groups, 302 | ) 303 | return F.conv2d( 304 | input, 305 | weight, 306 | self.get_bias(), 307 | self.stride, 308 | self.padding, 309 | self.dilation, 310 | self.groups, 311 | ) 312 | 313 | def forward(self, input: Tensor) -> Tensor: 314 | return self._conv_forward(input, self.get_weight()) 315 | 
316 | 317 | class ScaledLSTM(nn.LSTM): 318 | # See docs for ScaledLinear. 319 | # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` 320 | # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py 321 | def __init__( 322 | self, 323 | *args, 324 | initial_scale: float = 1.0, 325 | initial_speed: float = 1.0, 326 | grad_norm_threshold: float = 10.0, 327 | **kwargs, 328 | ): 329 | if "bidirectional" in kwargs: 330 | assert kwargs["bidirectional"] is False 331 | super(ScaledLSTM, self).__init__(*args, **kwargs) 332 | initial_scale = torch.tensor(initial_scale).log() 333 | self._scales_names = [] 334 | self._scales = [] 335 | for name in self._flat_weights_names: 336 | scale_name = name + "_scale" 337 | self._scales_names.append(scale_name) 338 | param = nn.Parameter(initial_scale.clone().detach()) 339 | setattr(self, scale_name, param) 340 | self._scales.append(param) 341 | 342 | self.grad_filter = GradientFilter( 343 | batch_dim=1, threshold=grad_norm_threshold 344 | ) 345 | 346 | self._reset_parameters( 347 | initial_speed 348 | ) # Overrides the reset_parameters in base class 349 | 350 | def _reset_parameters(self, initial_speed: float): 351 | std = 0.1 / initial_speed 352 | a = (3 ** 0.5) * std 353 | scale = self.hidden_size ** -0.5 354 | v = scale / std 355 | for idx, name in enumerate(self._flat_weights_names): 356 | if "weight" in name: 357 | nn.init.uniform_(self._flat_weights[idx], -a, a) 358 | with torch.no_grad(): 359 | self._scales[idx] += torch.tensor(v).log() 360 | elif "bias" in name: 361 | nn.init.constant_(self._flat_weights[idx], 0.0) 362 | 363 | def _flatten_parameters(self, flat_weights) -> None: 364 | """Resets parameter data pointer so that they can use faster code paths. 365 | 366 | Right now, this works only if the module is on the GPU and cuDNN is enabled. 367 | Otherwise, it's a no-op. 368 | 369 | This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa 370 | """ 371 | # Short-circuits if _flat_weights is only partially instantiated 372 | if len(flat_weights) != len(self._flat_weights_names): 373 | return 374 | 375 | for w in flat_weights: 376 | if not isinstance(w, Tensor): 377 | return 378 | # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN 379 | # or the tensors in flat_weights are of different dtypes 380 | 381 | first_fw = flat_weights[0] 382 | dtype = first_fw.dtype 383 | for fw in flat_weights: 384 | if ( 385 | not isinstance(fw.data, Tensor) 386 | or not (fw.data.dtype == dtype) 387 | or not fw.data.is_cuda 388 | or not torch.backends.cudnn.is_acceptable(fw.data) 389 | ): 390 | return 391 | 392 | # If any parameters alias, we fall back to the slower, copying code path. This is 393 | # a sufficient check, because overlapping parameter buffers that don't completely 394 | # alias would break the assumptions of the uniqueness check in 395 | # Module.named_parameters(). 
396 | unique_data_ptrs = set(p.data_ptr() for p in flat_weights) 397 | if len(unique_data_ptrs) != len(flat_weights): 398 | return 399 | 400 | with torch.cuda.device_of(first_fw): 401 | 402 | # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is 403 | # an inplace operation on self._flat_weights 404 | with torch.no_grad(): 405 | if torch._use_cudnn_rnn_flatten_weight(): 406 | num_weights = 4 if self.bias else 2 407 | if self.proj_size > 0: 408 | num_weights += 1 409 | torch._cudnn_rnn_flatten_weight( 410 | flat_weights, 411 | num_weights, 412 | self.input_size, 413 | rnn.get_cudnn_mode(self.mode), 414 | self.hidden_size, 415 | self.proj_size, 416 | self.num_layers, 417 | self.batch_first, 418 | bool(self.bidirectional), 419 | ) 420 | 421 | def _get_flat_weights(self): 422 | """Get scaled weights, and resets their data pointer.""" 423 | flat_weights = [] 424 | for idx in range(len(self._flat_weights_names)): 425 | flat_weights.append( 426 | self._flat_weights[idx] * self._scales[idx].exp() 427 | ) 428 | self._flatten_parameters(flat_weights) 429 | return flat_weights 430 | 431 | def forward( 432 | self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None 433 | ): 434 | # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa 435 | # The change for calling `_VF.lstm()` is: 436 | # self._flat_weights -> self._get_flat_weights() 437 | if hx is None: 438 | h_zeros = torch.zeros( 439 | self.num_layers, 440 | input.size(1), 441 | self.proj_size if self.proj_size > 0 else self.hidden_size, 442 | dtype=input.dtype, 443 | device=input.device, 444 | ) 445 | c_zeros = torch.zeros( 446 | self.num_layers, 447 | input.size(1), 448 | self.hidden_size, 449 | dtype=input.dtype, 450 | device=input.device, 451 | ) 452 | hx = (h_zeros, c_zeros) 453 | 454 | self.check_forward_args(input, hx, None) 455 | 456 | flat_weights = self._get_flat_weights() 457 | input, *flat_weights = self.grad_filter(input, *flat_weights) 458 | 459 | result = _VF.lstm( 460 | input, 461 | hx, 462 | flat_weights, 463 | self.bias, 464 | self.num_layers, 465 | self.dropout, 466 | self.training, 467 | self.bidirectional, 468 | self.batch_first, 469 | ) 470 | 471 | output = result[0] 472 | hidden = result[1:] 473 | return output, hidden 474 | 475 | class ActivationBalancerFunction(torch.autograd.Function): 476 | @staticmethod 477 | def forward( 478 | ctx, 479 | x: Tensor, 480 | channel_dim: int, 481 | min_positive: float, # e.g. 0.05 482 | max_positive: float, # e.g. 0.95 483 | max_factor: float, # e.g. 0.01 484 | min_abs: float, # e.g. 0.2 485 | max_abs: float, # e.g. 
100.0
486 |     ) -> Tensor:
487 |         if x.requires_grad:
488 |             if channel_dim < 0:
489 |                 channel_dim += x.ndim
490 | 
491 |             # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
492 |             # The above line is not torch scriptable for torch 1.6.0
493 |             # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet:  # noqa
494 |             sum_dims = []
495 |             for d in range(x.ndim):
496 |                 if d != channel_dim:
497 |                     sum_dims.append(d)
498 | 
499 |             xgt0 = x > 0
500 |             proportion_positive = torch.mean(
501 |                 xgt0.to(x.dtype), dim=sum_dims, keepdim=True
502 |             )
503 |             factor1 = (
504 |                 (min_positive - proportion_positive).relu()
505 |                 * (max_factor / min_positive)
506 |                 if min_positive != 0.0
507 |                 else 0.0
508 |             )
509 |             factor2 = (
510 |                 (proportion_positive - max_positive).relu()
511 |                 * (max_factor / (max_positive - 1.0))
512 |                 if max_positive != 1.0
513 |                 else 0.0
514 |             )
515 |             factor = factor1 + factor2
516 |             if isinstance(factor, float):
517 |                 factor = torch.zeros_like(proportion_positive)
518 | 
519 |             mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
520 |             below_threshold = mean_abs < min_abs
521 |             above_threshold = mean_abs > max_abs
522 | 
523 |             ctx.save_for_backward(
524 |                 factor, xgt0, below_threshold, above_threshold
525 |             )
526 |             ctx.max_factor = max_factor
527 |             ctx.sum_dims = sum_dims
528 |         return x
529 | 
530 |     @staticmethod
531 |     def backward(
532 |         ctx, x_grad: Tensor
533 |     ) -> Tuple[Tensor, None, None, None, None, None, None]:
534 |         factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
535 |         dtype = x_grad.dtype
536 |         scale_factor = (
537 |             (below_threshold.to(dtype) - above_threshold.to(dtype))
538 |             * (xgt0.to(dtype) - 0.5)
539 |             * (ctx.max_factor * 2.0)
540 |         )
541 | 
542 |         neg_delta_grad = x_grad.abs() * (factor + scale_factor)
543 |         return x_grad - neg_delta_grad, None, None, None, None, None, None
544 | 
545 | 
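# Example (illustrative): ActivationBalancer below never changes the forward
# value; it only perturbs the backward pass so each channel keeps a healthy
# sign balance and magnitude.
#
#     balancer = ActivationBalancer(channel_dim=-1)
#     x = torch.randn(10, 2, 512, requires_grad=True)
#     y = balancer(x)       # y is numerically identical to x
#     y.sum().backward()    # x.grad is gently nudged toward the constraints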
546 | class ActivationBalancer(torch.nn.Module):
547 |     """
548 |     Modifies the backpropped derivatives of a function to try to encourage, for
549 |     each channel, that it is positive at least a proportion `threshold` of the
550 |     time.  It does this by multiplying negative derivative values by up to
551 |     (1+max_factor), and positive derivative values by up to (1-max_factor),
552 |     interpolated from 1 at the threshold to those extremal values when none
553 |     of the inputs are positive.
554 | 
555 | 
556 |     Args:
557 |            channel_dim: the dimension/axis corresponding to the channel, e.g.
558 |                -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
559 |            min_positive: the minimum, per channel, of the proportion of the time
560 |                that (x > 0), below which we start to modify the derivatives.
561 |            max_positive: the maximum, per channel, of the proportion of the time
562 |                that (x > 0), above which we start to modify the derivatives.
563 |            max_factor: the maximum factor by which we modify the derivatives for
564 |               either the sign constraint or the magnitude constraint;
565 |               e.g. with max_factor=0.02, the derivatives would be multiplied by
566 |               values in the range [0.98..1.02].
567 |            min_abs:  the minimum average-absolute-value per channel, which
568 |               we allow, before we start to modify the derivatives to prevent
569 |               this.
570 |            max_abs:  the maximum average-absolute-value per channel, which
571 |               we allow, before we start to modify the derivatives to prevent
572 |               this.
573 |            balance_prob: the probability to apply the ActivationBalancer.
574 |     """
575 | 
576 |     def __init__(
577 |         self,
578 |         channel_dim: int,
579 |         min_positive: float = 0.05,
580 |         max_positive: float = 0.95,
581 |         max_factor: float = 0.01,
582 |         min_abs: float = 0.2,
583 |         max_abs: float = 100.0,
584 |         balance_prob: float = 0.25,
585 |     ):
586 |         super(ActivationBalancer, self).__init__()
587 |         self.channel_dim = channel_dim
588 |         self.min_positive = min_positive
589 |         self.max_positive = max_positive
590 |         self.max_factor = max_factor
591 |         self.min_abs = min_abs
592 |         self.max_abs = max_abs
593 |         assert 0 < balance_prob <= 1, balance_prob
594 |         self.balance_prob = balance_prob
595 | 
596 |     def forward(self, x: Tensor) -> Tensor:
597 |         if random.random() >= self.balance_prob:
598 |             return x
599 |         else:
600 |             return ActivationBalancerFunction.apply(
601 |                 x,
602 |                 self.channel_dim,
603 |                 self.min_positive,
604 |                 self.max_positive,
605 |                 self.max_factor / self.balance_prob,
606 |                 self.min_abs,
607 |                 self.max_abs,
608 |             )
609 | 
610 | 
611 | class DoubleSwishFunction(torch.autograd.Function):
612 |     """
613 |     double_swish(x) = x * torch.sigmoid(x-1)
614 |     This is a definition, originally motivated by its close numerical
615 |     similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
616 | 
617 |     Memory-efficient derivative computation:
618 |      double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
619 |      double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
620 |      Now, s'(x) = s(x) * (1-s(x)).
621 |      double_swish'(x) = x * s'(x) + s(x).
622 |                       = x * s(x) * (1-s(x)) + s(x).
623 |                       = double_swish(x) * (1-s(x)) + s(x)
624 |      ... so we just need to remember s(x) but not x itself.
625 |     """
626 | 
627 |     @staticmethod
628 |     def forward(ctx, x: Tensor) -> Tensor:
629 |         x = x.detach()
630 |         s = torch.sigmoid(x - 1.0)
631 |         y = x * s
632 |         ctx.save_for_backward(s, y)
633 |         return y
634 | 
635 |     @staticmethod
636 |     def backward(ctx, y_grad: Tensor) -> Tensor:
637 |         s, y = ctx.saved_tensors
638 |         return (y * (1 - s) + s) * y_grad
639 | 
640 | 
641 | class DoubleSwish(torch.nn.Module):
642 |     def forward(self, x: Tensor) -> Tensor:
643 |         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
644 |         that we approximate closely with x * sigmoid(x-1).
645 |         """
646 |         if torch.jit.is_scripting() or is_jit_tracing():
647 |             return x * torch.sigmoid(x - 1.0)
648 |         else:
649 |             return DoubleSwishFunction.apply(x)
650 | 
651 | 
652 | 
653 | 
654 | class Conv2dSubsampling(nn.Module):
655 |     """Convolutional 2D subsampling (to 1/4 length).
656 | 
657 |     Convert an input of shape (N, T, idim) to an output
658 |     with shape (N, T', odim), where
659 |     T' = ((T-3)//2-1)//2, which approximates T' == T//4
660 | 
661 |     It is based on
662 |     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
663 |     """
664 | 
665 |     def __init__(
666 |         self,
667 |         in_channels: int,
668 |         out_channels: int,
669 |         layer1_channels: int = 8,
670 |         layer2_channels: int = 32,
671 |         layer3_channels: int = 128,
672 |         is_pnnx: bool = False,
673 |     ) -> None:
674 |         """
675 |         Args:
676 |           in_channels:
677 |             Number of channels in. The input shape is (N, T, in_channels).
678 |             Caution: It requires: T >= 9, in_channels >= 9.
679 |           out_channels
680 |             Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels)
681 |           layer1_channels:
682 |             Number of channels in layer1
683 |           layer2_channels:
684 |             Number of channels in layer2
685 |           is_pnnx:
686 |             True if we are converting the model to PNNX format.
687 |             False otherwise.
688 |         """
689 |         assert in_channels >= 9
690 |         super().__init__()
691 | 
692 |         self.conv = nn.Sequential(
693 |             ScaledConv2d(
694 |                 in_channels=1,
695 |                 out_channels=layer1_channels,
696 |                 kernel_size=3,
697 |                 padding=0,
698 |             ),
699 |             ActivationBalancer(channel_dim=1),
700 |             DoubleSwish(),
701 |             ScaledConv2d(
702 |                 in_channels=layer1_channels,
703 |                 out_channels=layer2_channels,
704 |                 kernel_size=3,
705 |                 stride=2,
706 |             ),
707 |             ActivationBalancer(channel_dim=1),
708 |             DoubleSwish(),
709 |             ScaledConv2d(
710 |                 in_channels=layer2_channels,
711 |                 out_channels=layer3_channels,
712 |                 kernel_size=3,
713 |                 stride=2,
714 |             ),
715 |             ActivationBalancer(channel_dim=1),
716 |             DoubleSwish(),
717 |         )
718 |         self.out = ScaledLinear(
719 |             layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
720 |         )
721 |         # set learn_eps=False because out_norm is preceded by `out`, and `out`
722 |         # itself has learned scale, so the extra degree of freedom is not
723 |         # needed.
724 |         self.out_norm = BasicNorm(out_channels, learn_eps=False)
725 |         # constrain median of output to be close to zero.
726 |         self.out_balancer = ActivationBalancer(
727 |             channel_dim=-1, min_positive=0.45, max_positive=0.55
728 |         )
729 | 
730 |         # ncnn supports only batch size == 1
731 |         self.is_pnnx = is_pnnx
732 |         self.conv_out_dim = self.out.weight.shape[1]
733 | 
734 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
735 |         """Subsample x.
736 | 
737 |         Args:
738 |           x:
739 |             Its shape is (N, T, idim).
740 | 
741 |         Returns:
742 |           Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
743 |         """
744 |         # On entry, x is (N, T, idim)
745 |         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
746 |         x = self.conv(x)
747 | 
748 |         if torch.jit.is_tracing() and self.is_pnnx:
749 |             x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
750 |             x = self.out(x)
751 |         else:
752 |             # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
753 |             b, c, t, f = x.size()
754 |             x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
755 | 
756 |         # Now x is of shape (N, ((T-3)//2-1)//2, odim)
757 |         x = self.out_norm(x)
758 |         x = self.out_balancer(x)
759 |         return x
760 | 
761 | 
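# Example (illustrative): shape check for the 80-bin features used in this
# repo; 9 input frames collapse to a single subsampled frame.
#
#     embed = Conv2dSubsampling(in_channels=80, out_channels=512)
#     x = torch.randn(1, 9, 80)             # (N, T, idim), T >= 9 required
#     assert embed(x).shape == (1, 1, 512)  # (N, ((T-3)//2-1)//2, odim)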
762 | class RNNEncoderLayer(nn.Module):
763 |     """
764 |     RNNEncoderLayer is made up of lstm and feedforward networks.
765 | 
766 |     Args:
767 |       d_model:
768 |         The number of expected features in the input (required).
769 |       dim_feedforward:
770 |         The dimension of feedforward network model (default=2048).
771 |       rnn_hidden_size:
772 |         The hidden dimension of rnn layer.
773 |       dropout:
774 |         The dropout value (default=0.1).
775 |       layer_dropout:
776 |         The dropout value for model-level warmup (default=0.075).
777 |     """
778 | 
779 |     def __init__(
780 |         self,
781 |         d_model: int,
782 |         dim_feedforward: int,
783 |         rnn_hidden_size: int,
784 |         dropout: float = 0.1,
785 |         layer_dropout: float = 0.075,
786 |     ) -> None:
787 |         super(RNNEncoderLayer, self).__init__()
788 |         self.layer_dropout = layer_dropout
789 |         self.d_model = d_model
790 |         self.rnn_hidden_size = rnn_hidden_size
791 | 
792 |         assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
793 |         self.lstm = ScaledLSTM(
794 |             input_size=d_model,
795 |             hidden_size=rnn_hidden_size,
796 |             proj_size=d_model if rnn_hidden_size > d_model else 0,
797 |             num_layers=1,
798 |             dropout=0.0,
799 |         )
800 |         self.feed_forward = nn.Sequential(
801 |             ScaledLinear(d_model, dim_feedforward),
802 |             ActivationBalancer(channel_dim=-1),
803 |             DoubleSwish(),
804 |             nn.Dropout(dropout),
805 |             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
806 |         )
807 |         self.norm_final = BasicNorm(d_model)
808 | 
809 |         # try to ensure the output is close to zero-mean (or at least, zero-median).  # noqa
810 |         self.balancer = ActivationBalancer(
811 |             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
812 |         )
813 |         self.dropout = nn.Dropout(dropout)
814 | 
815 |     def forward(
816 |         self,
817 |         src: torch.Tensor,
818 |         states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
819 |         warmup: float = 1.0,
820 |     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
821 |         """
822 |         Pass the input through the encoder layer.
823 | 
824 |         Args:
825 |           src:
826 |             The sequence to the encoder layer (required).
827 |             Its shape is (S, N, E), where S is the sequence length,
828 |             N is the batch size, and E is the feature number.
829 |           states:
830 |             A tuple of 2 tensors (optional). It is for streaming inference.
831 |             states[0] is the hidden states of all layers,
832 |               with shape of (1, N, d_model);
833 |             states[1] is the cell states of all layers,
834 |               with shape of (1, N, rnn_hidden_size).
835 |           warmup:
836 |             It controls selective bypass of layers; if < 1.0, we will
837 |             bypass layers more frequently.
838 |         """
839 |         src_orig = src
840 | 
841 |         warmup_scale = min(0.1 + warmup, 1.0)
842 |         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
843 |         # completely bypass it.
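        # For example: at warmup=0.0 a kept layer is mixed in with alpha=0.1
        # (mostly bypassed), at warmup=0.5 with alpha=0.6, and from
        # warmup >= 0.9 onward with alpha=1.0; dropped layers always get 0.1.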
844 | if self.training: 845 | alpha = ( 846 | warmup_scale 847 | if torch.rand(()).item() <= (1.0 - self.layer_dropout) 848 | else 0.1 849 | ) 850 | else: 851 | alpha = 1.0 852 | 853 | # lstm module 854 | if states is None: 855 | src_lstm = self.lstm(src)[0] 856 | # torch.jit.trace requires returned types be the same as annotated 857 | new_states = (torch.empty(0), torch.empty(0)) 858 | else: 859 | assert not self.training 860 | assert len(states) == 2 861 | if not torch.jit.is_tracing(): 862 | # for hidden state 863 | assert states[0].shape == (1, src.size(1), self.d_model) 864 | # for cell state 865 | assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) 866 | src_lstm, new_states = self.lstm(src, states) 867 | src = self.dropout(src_lstm) + src 868 | 869 | # feed forward module 870 | src = src + self.dropout(self.feed_forward(src)) 871 | 872 | src = self.norm_final(self.balancer(src)) 873 | 874 | if alpha != 1.0: 875 | src = alpha * src + (1 - alpha) * src_orig 876 | 877 | return src, new_states 878 | 879 | 880 | class RandomCombine(nn.Module): 881 | """ 882 | This module combines a list of Tensors, all with the same shape, to 883 | produce a single output of that same shape which, in training time, 884 | is a random combination of all the inputs; but which in test time 885 | will be just the last input. 886 | 887 | The idea is that the list of Tensors will be a list of outputs of multiple 888 | conformer layers. This has a similar effect as iterated loss. (See: 889 | DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER 890 | NETWORKS). 891 | """ 892 | 893 | def __init__( 894 | self, 895 | num_inputs: int, 896 | final_weight: float = 0.5, 897 | pure_prob: float = 0.5, 898 | stddev: float = 2.0, 899 | ) -> None: 900 | """ 901 | Args: 902 | num_inputs: 903 | The number of tensor inputs, which equals the number of layers' 904 | outputs that are fed into this module. E.g. in an 18-layer neural 905 | net if we output layers 16, 12, 18, num_inputs would be 3. 906 | final_weight: 907 | The amount of weight or probability we assign to the 908 | final layer when randomly choosing layers or when choosing 909 | continuous layer weights. 910 | pure_prob: 911 | The probability, on each frame, with which we choose 912 | only a single layer to output (rather than an interpolation) 913 | stddev: 914 | A standard deviation that we add to log-probs for computing 915 | randomized weights. 916 | 917 | The method of choosing which layers, or combinations of layers, to use, 918 | is conceptually as follows:: 919 | 920 | With probability `pure_prob`:: 921 | With probability `final_weight`: choose final layer, 922 | Else: choose random non-final layer. 923 | Else:: 924 | Choose initial log-weights that correspond to assigning 925 | weight `final_weight` to the final layer and equal 926 | weights to other layers; then add Gaussian noise 927 | with variance `stddev` to these log-weights, and normalize 928 | to weights (note: the average weight assigned to the 929 | final layer here will not be `final_weight` if stddev>0). 
930 | """ 931 | super().__init__() 932 | assert 0 <= pure_prob <= 1, pure_prob 933 | assert 0 < final_weight < 1, final_weight 934 | assert num_inputs >= 1 935 | 936 | self.num_inputs = num_inputs 937 | self.final_weight = final_weight 938 | self.pure_prob = pure_prob 939 | self.stddev = stddev 940 | 941 | self.final_log_weight = ( 942 | torch.tensor( 943 | (final_weight / (1 - final_weight)) * (self.num_inputs - 1) 944 | ) 945 | .log() 946 | .item() 947 | ) 948 | 949 | def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: 950 | """Forward function. 951 | Args: 952 | inputs: 953 | A list of Tensor, e.g. from various layers of a transformer. 954 | All must be the same shape, of (*, num_channels) 955 | Returns: 956 | A Tensor of shape (*, num_channels). In test mode 957 | this is just the final input. 958 | """ 959 | num_inputs = self.num_inputs 960 | assert len(inputs) == num_inputs 961 | if not self.training or torch.jit.is_scripting(): 962 | return inputs[-1] 963 | 964 | # Shape of weights: (*, num_inputs) 965 | num_channels = inputs[0].shape[-1] 966 | num_frames = inputs[0].numel() // num_channels 967 | 968 | ndim = inputs[0].ndim 969 | # stacked_inputs: (num_frames, num_channels, num_inputs) 970 | stacked_inputs = torch.stack(inputs, dim=ndim).reshape( 971 | (num_frames, num_channels, num_inputs) 972 | ) 973 | 974 | # weights: (num_frames, num_inputs) 975 | weights = self._get_random_weights( 976 | inputs[0].dtype, inputs[0].device, num_frames 977 | ) 978 | 979 | weights = weights.reshape(num_frames, num_inputs, 1) 980 | # ans: (num_frames, num_channels, 1) 981 | ans = torch.matmul(stacked_inputs, weights) 982 | # ans: (*, num_channels) 983 | 984 | ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) 985 | 986 | # The following if causes errors for torch script in torch 1.6.0 987 | # if __name__ == "__main__": 988 | # # for testing only... 989 | # print("Weights = ", weights.reshape(num_frames, num_inputs)) 990 | return ans 991 | 992 | def _get_random_weights( 993 | self, dtype: torch.dtype, device: torch.device, num_frames: int 994 | ) -> torch.Tensor: 995 | """Return a tensor of random weights, of shape 996 | `(num_frames, self.num_inputs)`, 997 | Args: 998 | dtype: 999 | The data-type desired for the answer, e.g. float, double. 1000 | device: 1001 | The device needed for the answer. 1002 | num_frames: 1003 | The number of sets of weights desired 1004 | Returns: 1005 | A tensor of shape (num_frames, self.num_inputs), such that 1006 | `ans.sum(dim=1)` is all ones. 1007 | """ 1008 | pure_prob = self.pure_prob 1009 | if pure_prob == 0.0: 1010 | return self._get_random_mixed_weights(dtype, device, num_frames) 1011 | elif pure_prob == 1.0: 1012 | return self._get_random_pure_weights(dtype, device, num_frames) 1013 | else: 1014 | p = self._get_random_pure_weights(dtype, device, num_frames) 1015 | m = self._get_random_mixed_weights(dtype, device, num_frames) 1016 | return torch.where( 1017 | torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m 1018 | ) 1019 | 1020 | def _get_random_pure_weights( 1021 | self, dtype: torch.dtype, device: torch.device, num_frames: int 1022 | ): 1023 | """Return a tensor of random one-hot weights, of shape 1024 | `(num_frames, self.num_inputs)`, 1025 | Args: 1026 | dtype: 1027 | The data-type desired for the answer, e.g. float, double. 1028 | device: 1029 | The device needed for the answer. 1030 | num_frames: 1031 | The number of sets of weights desired. 
1032 |         Returns:
1033 |           A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
1034 |           exactly one weight equal to 1.0 on each frame.
1035 |         """
1036 |         final_prob = self.final_weight
1037 | 
1038 |         # final contains self.num_inputs - 1 in all elements
1039 |         final = torch.full((num_frames,), self.num_inputs - 1, device=device)
1040 |         # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.  # noqa
1041 |         nonfinal = torch.randint(
1042 |             self.num_inputs - 1, (num_frames,), device=device
1043 |         )
1044 | 
1045 |         indexes = torch.where(
1046 |             torch.rand(num_frames, device=device) < final_prob, final, nonfinal
1047 |         )
1048 |         ans = torch.nn.functional.one_hot(
1049 |             indexes, num_classes=self.num_inputs
1050 |         ).to(dtype=dtype)
1051 |         return ans
1052 | 
1053 |     def _get_random_mixed_weights(
1054 |         self, dtype: torch.dtype, device: torch.device, num_frames: int
1055 |     ):
1056 |         """Return a tensor of random mixed weights, of shape
1057 |         `(num_frames, self.num_inputs)`,
1058 |         Args:
1059 |           dtype:
1060 |             The data-type desired for the answer, e.g. float, double.
1061 |           device:
1062 |             The device needed for the answer.
1063 |           num_frames:
1064 |             The number of sets of weights desired.
1065 |         Returns:
1066 |           A tensor of shape (num_frames, self.num_inputs), with elements
1067 |           in [0..1] that sum to one over the second axis, i.e.
1068 |           `ans.sum(dim=1)` is all ones.
1069 |         """
1070 |         logprobs = (
1071 |             torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
1072 |             * self.stddev  # noqa
1073 |         )
1074 |         logprobs[:, -1] += self.final_log_weight
1075 |         return logprobs.softmax(dim=1)
1076 | 
1077 | 
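# Example (illustrative): RandomCombine mixes layer outputs only in training;
# in eval mode it simply returns the last input.
#
#     rc = RandomCombine(num_inputs=3)
#     xs = [torch.randn(5, 2, 512) for _ in range(3)]
#     rc.train(); mixed = rc(xs)                    # random per-frame mixture
#     rc.eval();  assert torch.equal(rc(xs), xs[-1])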
1078 | class RNNEncoder(nn.Module):
1079 |     """
1080 |     RNNEncoder is a stack of N encoder layers.
1081 | 
1082 |     Args:
1083 |       encoder_layer:
1084 |         An instance of the RNNEncoderLayer() class (required).
1085 |       num_layers:
1086 |         The number of sub-encoder-layers in the encoder (required).
1087 |     """
1088 | 
1089 |     def __init__(
1090 |         self,
1091 |         encoder_layer: nn.Module,
1092 |         num_layers: int,
1093 |         aux_layers: Optional[List[int]] = None,
1094 |     ) -> None:
1095 |         super(RNNEncoder, self).__init__()
1096 |         self.layers = nn.ModuleList(
1097 |             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
1098 |         )
1099 |         self.num_layers = num_layers
1100 |         self.d_model = encoder_layer.d_model
1101 |         self.rnn_hidden_size = encoder_layer.rnn_hidden_size
1102 | 
1103 |         self.aux_layers: List[int] = []
1104 |         self.combiner: Optional[nn.Module] = None
1105 |         if aux_layers is not None:
1106 |             assert len(set(aux_layers)) == len(aux_layers)
1107 |             assert num_layers - 1 not in aux_layers
1108 |             self.aux_layers = aux_layers + [num_layers - 1]
1109 |             self.combiner = RandomCombine(
1110 |                 num_inputs=len(self.aux_layers),
1111 |                 final_weight=0.5,
1112 |                 pure_prob=0.333,
1113 |                 stddev=2.0,
1114 |             )
1115 | 
1116 |     def trim_layers(self, stride):
1117 |         self.aux_layers = []
1118 |         self.combiner = None
1119 |         self.layers = self.layers[::stride]
1120 |         self.num_layers = len(self.layers)
1121 | 
1122 |     def forward(
1123 |         self,
1124 |         src: torch.Tensor,
1125 |         states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1126 |         warmup: float = 1.0,
1127 |     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1128 |         """
1129 |         Pass the input through the encoder layer in turn.
1130 | 
1131 |         Args:
1132 |           src:
1133 |             The sequence to the encoder layer (required).
1134 |             Its shape is (S, N, E), where S is the sequence length,
1135 |             N is the batch size, and E is the feature number.
1136 |           states:
1137 |             A tuple of 2 tensors (optional). It is for streaming inference.
1138 |             states[0] is the hidden states of all layers,
1139 |               with shape of (num_layers, N, d_model);
1140 |             states[1] is the cell states of all layers,
1141 |               with shape of (num_layers, N, rnn_hidden_size).
1142 |           warmup:
1143 |             It controls selective bypass of layers; if < 1.0, we will
1144 |             bypass layers more frequently.
1145 |         """
1146 |         if states is not None:
1147 |             assert not self.training
1148 |             assert len(states) == 2
1149 |             if not torch.jit.is_tracing():
1150 |                 # for hidden state
1151 |                 assert states[0].shape == (
1152 |                     self.num_layers,
1153 |                     src.size(1),
1154 |                     self.d_model,
1155 |                 )
1156 |                 # for cell state
1157 |                 assert states[1].shape == (
1158 |                     self.num_layers,
1159 |                     src.size(1),
1160 |                     self.rnn_hidden_size,
1161 |                 )
1162 | 
1163 |         output = src
1164 | 
1165 |         outputs = []
1166 | 
1167 |         new_hidden_states = []
1168 |         new_cell_states = []
1169 | 
1170 |         for i, mod in enumerate(self.layers):
1171 |             if states is None:
1172 |                 output = mod(output, warmup=warmup)[0]
1173 |             else:
1174 |                 layer_state = (
1175 |                     states[0][i : i + 1, :, :],  # h: (1, N, d_model)
1176 |                     states[1][i : i + 1, :, :],  # c: (1, N, rnn_hidden_size)
1177 |                 )
1178 |                 output, (h, c) = mod(output, layer_state)
1179 |                 new_hidden_states.append(h)
1180 |                 new_cell_states.append(c)
1181 | 
1182 |             if self.combiner is not None and i in self.aux_layers:
1183 |                 outputs.append(output)
1184 | 
1185 |         if self.combiner is not None:
1186 |             output = self.combiner(outputs)
1187 | 
1188 |         if states is None:
1189 |             new_states = (torch.empty(0), torch.empty(0))
1190 |         else:
1191 |             new_states = (
1192 |                 torch.cat(new_hidden_states, dim=0),
1193 |                 torch.cat(new_cell_states, dim=0),
1194 |             )
1195 | 
1196 |         return output, new_states
1197 | 
1198 | 
1199 | class RNN(EncoderInterface):
1200 |     """
1201 |     Args:
1202 |       num_features (int):
1203 |         Number of input features.
1204 |       subsampling_factor (int):
1205 |         Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
1206 |       d_model (int):
1207 |         Output dimension (default=512).
1208 |       dim_feedforward (int):
1209 |         Feedforward dimension (default=2048).
1210 |       rnn_hidden_size (int):
1211 |         Hidden dimension for lstm layers (default=1024).
1212 |       num_encoder_layers (int):
1213 |         Number of encoder layers (default=12).
1214 |       dropout (float):
1215 |         Dropout rate (default=0.1).
1216 |       layer_dropout (float):
1217 |         Dropout value for model-level warmup (default=0.075).
1218 |       aux_layer_period (int):
1219 |         Period of auxiliary layers used for random combiner during training.
1220 |         If set to 0, will not use the random combiner (Default).
1221 |         You can set a positive integer to use the random combiner, e.g., 3.
1222 |       is_pnnx:
1223 |         True to make this class exportable via PNNX.
1224 |     """
1225 | 
1226 |     def __init__(
1227 |         self,
1228 |         num_features: int,
1229 |         subsampling_factor: int = 4,
1230 |         d_model: int = 512,
1231 |         dim_feedforward: int = 2048,
1232 |         rnn_hidden_size: int = 1024,
1233 |         num_encoder_layers: int = 12,
1234 |         dropout: float = 0.1,
1235 |         layer_dropout: float = 0.075,
1236 |         aux_layer_period: int = 0,
1237 |         is_pnnx: bool = False,
1238 |     ) -> None:
1239 |         super(RNN, self).__init__()
1240 | 
1241 |         self.num_features = num_features
1242 |         self.subsampling_factor = subsampling_factor
1243 |         if subsampling_factor != 4:
1244 |             raise NotImplementedError("Support only 'subsampling_factor=4'.")
1245 | 
1246 |         # self.encoder_embed converts the input of shape (N, T, num_features)
1247 |         # to the shape (N, T//subsampling_factor, d_model).
1248 |         # That is, it does two things simultaneously:
1249 |         #   (1) subsampling: T -> T//subsampling_factor
1250 |         #   (2) embedding: num_features -> d_model
1251 |         self.encoder_embed = Conv2dSubsampling(
1252 |             num_features,
1253 |             d_model,
1254 |             is_pnnx=is_pnnx,
1255 |         )
1256 | 
1257 |         self.is_pnnx = is_pnnx
1258 | 
1259 |         self.num_encoder_layers = num_encoder_layers
1260 |         self.d_model = d_model
1261 |         self.rnn_hidden_size = rnn_hidden_size
1262 | 
1263 |         encoder_layer = RNNEncoderLayer(
1264 |             d_model=d_model,
1265 |             dim_feedforward=dim_feedforward,
1266 |             rnn_hidden_size=rnn_hidden_size,
1267 |             dropout=dropout,
1268 |             layer_dropout=layer_dropout,
1269 |         )
1270 |         self.encoder = RNNEncoder(
1271 |             encoder_layer,
1272 |             num_encoder_layers,
1273 |             aux_layers=list(
1274 |                 range(
1275 |                     num_encoder_layers // 3,
1276 |                     num_encoder_layers - 1,
1277 |                     aux_layer_period,
1278 |                 )
1279 |             )
1280 |             if aux_layer_period > 0
1281 |             else None,
1282 |         )
1283 | 
1284 |     def forward(
1285 |         self,
1286 |         x: torch.Tensor,
1287 |         x_lens: torch.Tensor,
1288 |         states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1289 |         warmup: float = 1.0,
1290 |     ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1291 |         """
1292 |         Args:
1293 |           x:
1294 |             The input tensor. Its shape is (N, T, C), where N is the batch size,
1295 |             T is the sequence length, C is the feature dimension.
1296 |           x_lens:
1297 |             A tensor of shape (N,), containing the number of frames in `x`
1298 |             before padding.
1299 |           states:
1300 |             A tuple of 2 tensors (optional). It is for streaming inference.
1301 |             states[0] is the hidden states of all layers,
1302 |               with shape of (num_layers, N, d_model);
1303 |             states[1] is the cell states of all layers,
1304 |               with shape of (num_layers, N, rnn_hidden_size).
1305 |           warmup:
1306 |             A floating point value that gradually increases from 0 throughout
1307 |             training; when it is >= 1.0 we are "fully warmed up". It is used
1308 |             to turn modules on sequentially.
1309 | 
1310 |         Returns:
1311 |           A tuple of 3 tensors:
1312 |             - embeddings: its shape is (N, T', d_model), where T' is the output
1313 |               sequence lengths.
1314 |             - lengths: a tensor of shape (batch_size,) containing the number of
1315 |               frames in `embeddings` before padding.
1316 |             - updated states, whose shape is the same as the input states.
1317 |         """
1318 |         x = self.encoder_embed(x)
1319 |         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
1320 | 
1321 |         # lengths = ((x_lens - 3) // 2 - 1) // 2  # issues a warning
1322 |         #
1323 |         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
1324 |         if not self.is_pnnx:
1325 |             lengths = (((x_lens - 3) >> 1) - 1) >> 1
1326 |         else:
1327 |             lengths1 = torch.floor((x_lens - 3) / 2)
1328 |             lengths = torch.floor((lengths1 - 1) / 2)
1329 |             lengths = lengths.to(x_lens)
1330 | 
1331 |         if not torch.jit.is_tracing():
1332 |             assert x.size(0) == lengths.max().item()
1333 | 
1334 |         if states is None:
1335 |             x = self.encoder(x, warmup=warmup)[0]
1336 |             # torch.jit.trace requires returned types to be the same as annotated  # noqa
1337 |             new_states = (torch.empty(0), torch.empty(0))
1338 |         else:
1339 |             assert not self.training
1340 |             assert len(states) == 2
1341 |             if not torch.jit.is_tracing():
1342 |                 # for hidden state
1343 |                 assert states[0].shape == (
1344 |                     self.num_encoder_layers,
1345 |                     x.size(1),
1346 |                     self.d_model,
1347 |                 )
1348 |                 # for cell state
1349 |                 assert states[1].shape == (
1350 |                     self.num_encoder_layers,
1351 |                     x.size(1),
1352 |                     self.rnn_hidden_size,
1353 |                 )
1354 |             x, new_states = self.encoder(x, states)
1355 | 
1356 |         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
1357 |         return x, lengths, new_states
1358 | 
1359 |     @torch.jit.export
1360 |     def get_init_states(
1361 |         self, batch_size: int = 1, device: torch.device = torch.device("cpu")
1362 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
1363 |         """Get model initial states."""
1364 |         # for rnn hidden states
1365 |         hidden_states = torch.zeros(
1366 |             (self.num_encoder_layers, batch_size, self.d_model), device=device
1367 |         )
1368 |         cell_states = torch.zeros(
1369 |             (self.num_encoder_layers, batch_size, self.rnn_hidden_size),
1370 |             device=device,
1371 |         )
1372 |         return (hidden_states, cell_states)
1373 | 
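# Example (illustrative): streaming use of the encoder. get_init_states()
# builds the same zero states that infer.py constructs by hand; chunks are fed
# one at a time while threading the states through.
#
#     rnn = RNN(num_features=80, num_encoder_layers=2).eval()
#     states = rnn.get_init_states(batch_size=1)
#     chunk = torch.randn(1, 9, 80)              # (N, T, C)
#     out, out_lens, states = rnn(chunk, torch.tensor([9]), states=states)
#     # out: (1, 1, 512); `states` carries the LSTM memory into the next chunk.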
1374 | 
1375 | class EncoderModelForLoad(nn.Module):
1376 |     def __init__(
1377 |         self,
1378 |         num_features: int = 80
1379 |     ) -> None:
1380 |         super(EncoderModelForLoad, self).__init__()
1381 |         self.encoder = RNN(num_features=num_features)  # wrapper so checkpoints keyed under "encoder." load cleanly
1382 | 
1383 | class StreamingTurnModel(nn.Module):
1384 |     def __init__(
1385 |         self,
1386 |         loaded_model,
1387 |         num_layers=12
1388 |     ) -> None:
1389 |         super(StreamingTurnModel, self).__init__()
1390 |         if loaded_model is not None:
1391 |             self.encoder = loaded_model.encoder
1392 |         else:
1393 |             self.encoder = RNN(num_features=80, num_encoder_layers=num_layers)
1394 |         self.decoder = nn.Sequential(
1395 |             nn.Linear(512, 1),  # d_model=512 encoder output -> one endpoint logit
1396 |             # nn.Sigmoid()  # left out; infer.py applies the sigmoid itself
1397 |         )
1398 | 
1399 |     def forward(
1400 |         self,
1401 |         x: torch.Tensor,
1402 |         x_lens: torch.Tensor,
1403 |         states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1404 |         warmup: float = 1.0,
1405 |     ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1406 |         x, x_lens, states = self.encoder(x, x_lens, states=states, warmup=warmup)
1407 |         x = self.decoder(x)
1408 |         return x, x_lens, states
1409 | 
--------------------------------------------------------------------------------
/train-rnn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "e29d978a-83b7-40ff-beea-2312db0e89ac",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "from datasets import load_from_disk\n",
11 |     "train_dataset = load_from_disk(\"smart-turn/datasets/human_5_all/\")[\"train\"]"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 | 
"execution_count": 2, 17 | "id": "9ef6ef02-80a9-44ed-aec4-b06d2c64f3c3", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "application/vnd.jupyter.widget-view+json": { 23 | "model_id": "49a5ac8d35f64498b53d44bf529965b0", 24 | "version_major": 2, 25 | "version_minor": 0 26 | }, 27 | "text/plain": [ 28 | "Map: 0%| | 0/3862 [00:00" 266 | ] 267 | }, 268 | "metadata": {}, 269 | "output_type": "display_data" 270 | } 271 | ], 272 | "source": [ 273 | "import seaborn as sns\n", 274 | "import matplotlib.pyplot as plt\n", 275 | "\n", 276 | "# You can manually set some value so that the bottom left and top right are roughly equal\n", 277 | "threshold = 0.024#sum(preds) / len(preds)\n", 278 | "print(threshold)\n", 279 | "\n", 280 | "plt.figure(figsize=(8, 6))\n", 281 | "sns.heatmap(confusion_matrix(labels, [1 if x > threshold else 0 for x in preds]),\n", 282 | " annot=True, fmt='d', cmap='Blues',\n", 283 | " xticklabels=['Incomplete', 'Complete'],\n", 284 | " yticklabels=['Incomplete', 'Complete'])\n", 285 | "plt.title(f'Confusion Matrix - Set')\n", 286 | "plt.ylabel('True Label')\n", 287 | "plt.xlabel('Predicted Label')\n", 288 | "plt.tight_layout()\n", 289 | "plt.show()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "edf21088-61c6-487e-b311-9121a85d6b3d", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "import IPython.display as ipd\n", 300 | "import numpy as np\n", 301 | "\n", 302 | "\n", 303 | "for x in s_test_dataset.shuffle():\n", 304 | " audio = x[\"audio\"]\n", 305 | " print(x[\"endpoint_bool\"])\n", 306 | " ipd.display(ipd.Audio(audio[\"array\"], rate=audio[\"sampling_rate\"], autoplay=True))\n", 307 | " break" 308 | ] 309 | } 310 | ], 311 | "metadata": { 312 | "kernelspec": { 313 | "display_name": "Python 3 (ipykernel)", 314 | "language": "python", 315 | "name": "python3" 316 | }, 317 | "language_info": { 318 | "codemirror_mode": { 319 | "name": "ipython", 320 | "version": 3 321 | }, 322 | "file_extension": ".py", 323 | "mimetype": "text/x-python", 324 | "name": "python", 325 | "nbconvert_exporter": "python", 326 | "pygments_lexer": "ipython3", 327 | "version": "3.10.12" 328 | } 329 | }, 330 | "nbformat": 4, 331 | "nbformat_minor": 5 332 | } 333 | --------------------------------------------------------------------------------