├── README.md
└── lstmp.py

/README.md:
--------------------------------------------------------------------------------
## Description

I am researching end-to-end ASR models such as CTC and the Transducer, and many LSTM variants have been proposed for the ASR task. In the [paper](https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf), an LSTM with a recurrent projection layer (LSTMP) achieves better performance. Since LSTMP is not supported by PyTorch, I implemented this custom LSTM following [this tutorial](https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/). I hope it helps other researchers.

## References

[Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling](https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf)

[Optimizing CUDA Recurrent Neural Networks with TorchScript](https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/)
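## Usage

A minimal usage sketch, assuming `lstmp.py` is on the import path (it mirrors the `test()` function at the bottom of the file): the input layout is `(seq_len, batch, input_size)`, and the initial state must be passed explicitly as an `(hx, cx)` tuple.

```python
import torch

from lstmp import LSTMPLayer

rnn = LSTMPLayer(input_size=320, hidden_size=768, projection_size=256)
x = torch.rand(50, 4, 320)      # (seq_len, batch, input_size)
hx = x.new_zeros(4, 256)        # initial hidden state, already projected
cx = x.new_zeros(4, 768)        # initial cell state
y, (h, c) = rnn(x, (hx, cx))    # y: (50, 4, 256)
```

Note that the projection shrinks the per-step output from `hidden_size` to `projection_size`, which is what reduces the parameter count of the recurrent weights.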
--------------------------------------------------------------------------------
/lstmp.py:
--------------------------------------------------------------------------------
import math
from typing import List, Tuple

import torch
import torch.jit as jit
from torch import Tensor
from torch.nn import Parameter, init


class LSTMPCell(jit.ScriptModule):
    """LSTM cell with a recurrent projection layer (LSTMP)."""

    def __init__(self, input_size, hidden_size, projection_size):
        super(LSTMPCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.projection_size = projection_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        # The recurrent weights act on the projected state, so their second
        # dimension is projection_size rather than hidden_size.
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, projection_size))
        self.weight_hr = Parameter(torch.randn(projection_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))
        self.init_weights()

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # input: (batch_size, input_size)
        # state: hx -> (batch_size, projection_size)
        #        cx -> (batch_size, hidden_size)
        # state cannot be None: the caller must pass explicit (e.g. zero)
        # initial states, since an Optional state fails to script here.
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        # Recurrent projection: (batch, hidden_size) -> (batch, projection_size).
        hy = torch.mm(hy, self.weight_hr.t())

        return hy, (hy, cy)

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        init.uniform_(self.weight_ih, -stdv, stdv)
        init.uniform_(self.weight_hh, -stdv, stdv)
        init.uniform_(self.weight_hr, -stdv, stdv)
        # Biases use the same fan-in based range as the weights.
        init.uniform_(self.bias_ih, -stdv, stdv)
        init.uniform_(self.bias_hh, -stdv, stdv)


class LSTMPLayer(jit.ScriptModule):
    """Unrolls an LSTMPCell over the first (time) dimension of the input."""

    def __init__(self, input_size, hidden_size, projection_size):
        super(LSTMPLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.projection_size = projection_size
        self.cell = LSTMPCell(input_size=input_size, hidden_size=hidden_size,
                              projection_size=projection_size)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # input: (seq_len, batch_size, input_size); state cannot be None.
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class LSTMCell(jit.ScriptModule):
    """Plain LSTM cell without projection, kept as a baseline for comparison."""

    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))
        self.init_weights()

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        init.uniform_(self.weight_ih, -stdv, stdv)
        init.uniform_(self.weight_hh, -stdv, stdv)
        init.uniform_(self.bias_ih, -stdv, stdv)
        init.uniform_(self.bias_hh, -stdv, stdv)


class LSTMLayer(jit.ScriptModule):
    """Unrolls an arbitrary cell (e.g. LSTMCell) over the time dimension."""

    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


def test():
    input_size = 320
    hidden_size = 768
    projection_size = 256
    rnn = LSTMPLayer(input_size=input_size, hidden_size=hidden_size,
                     projection_size=projection_size)
    x = torch.rand(50, 4, input_size)  # (seq_len, batch, input_size)
    hx = x.new_zeros(x.size(1), projection_size, requires_grad=False)
    cx = x.new_zeros(x.size(1), hidden_size, requires_grad=False)
    # The state must be a tuple, not a list, to match the
    # Tuple[Tensor, Tensor] annotation of the scripted forward.
    state = (hx, cx)
    y, h = rnn(x, state)
    assert y.shape == (50, 4, projection_size)


if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------
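The referenced paper evaluates deep architectures that stack several LSTMP layers, each layer consuming the projected output of the one below it. A minimal sketch of that wiring on top of `LSTMPLayer` follows; the `StackedLSTMP` wrapper and its per-layer zero initial states are illustrative additions, not part of `lstmp.py`:

```python
import torch
import torch.nn as nn

from lstmp import LSTMPLayer


class StackedLSTMP(nn.Module):
    """Hypothetical helper: chains LSTMPLayer modules so that layer i
    consumes the projected (projection_size) outputs of layer i - 1."""

    def __init__(self, num_layers, input_size, hidden_size, projection_size):
        super(StackedLSTMP, self).__init__()
        # Only the first layer sees input_size features; the rest see the
        # projected outputs of the previous layer.
        in_sizes = [input_size] + [projection_size] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [LSTMPLayer(s, hidden_size, projection_size) for s in in_sizes])
        self.hidden_size = hidden_size
        self.projection_size = projection_size

    def forward(self, x):
        # x: (seq_len, batch, input_size)
        out = x
        for layer in self.layers:
            hx = out.new_zeros(out.size(1), self.projection_size)
            cx = out.new_zeros(out.size(1), self.hidden_size)
            out, _ = layer(out, (hx, cx))  # (seq_len, batch, projection_size)
        return out


enc = StackedLSTMP(num_layers=3, input_size=320, hidden_size=768,
                   projection_size=256)
y = enc(torch.rand(50, 4, 320))  # (50, 4, 256)
```

Starting every layer from a fresh zero state keeps the sketch short; a streaming ASR encoder would instead carry each layer's `(hx, cx)` across chunks.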