├── linear_rnn
│   ├── layers
│   │   ├── __init__.py
│   │   ├── gilr.py
│   │   └── lru.py
│   └── scan_triton
│       ├── __init__.py
│       ├── complex_rnn.py
│       └── real_rnn_tie_input_gate.py
├── README.md
├── .gitignore
└── tests
    ├── test_real_rnn.py
    └── test_complex_rnn.py

/linear_rnn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .lru import *
2 | from .gilr import *
3 | 
4 | 
5 | 
--------------------------------------------------------------------------------
/linear_rnn/scan_triton/__init__.py:
--------------------------------------------------------------------------------
1 | from .complex_rnn import *
2 | from .real_rnn_tie_input_gate import *
3 | 
4 | 
5 | 
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Linear RNN Layer Zoo in PyTorch
2 | Supported layers:
3 | - [Linear Recurrent Unit (LRU)](https://arxiv.org/pdf/2303.06349.pdf)
4 | - [Gated Impulse Linear Recurrent (GILR) layer](https://openreview.net/pdf?id=HyUNwulC-)
5 | 
6 | TODO:
7 | - S5
8 | - Simple Recurrent Unit (SRU)
9 | 
10 | # Requirements
11 | Install PyTorch and Triton (https://github.com/openai/triton); a CUDA-capable GPU is required.
12 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | 
5 | # C extensions
6 | *.so
7 | 
8 | # Distribution / packaging
9 | bin/
10 | build/
11 | develop-eggs/
12 | dist/
13 | eggs/
14 | lib/
15 | lib64/
16 | 
17 | parts/
18 | sdist/
19 | var/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 | 
24 | # Installer logs
25 | pip-log.txt
26 | pip-delete-this-directory.txt
27 | 
28 | # Unit test / coverage reports
29 | .tox/
30 | .coverage
31 | .cache
32 | nosetests.xml
33 | coverage.xml
34 | 
35 | # Translations
36 | *.mo
37 | 
38 | # Mr Developer
39 | .mr.developer.cfg
40 | .project
41 | .pydevproject
42 | 
43 | # Rope
44 | .ropeproject
45 | 
46 | # Django stuff:
47 | *.log
48 | *.pot
49 | 
50 | # Sphinx documentation
51 | docs/_build/
52 | 
53 | 
54 | test.py
--------------------------------------------------------------------------------
/linear_rnn/layers/gilr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from linear_rnn.scan_triton import real_scan_tie_input_gate, real_scan_tie_input_gate_fused
5 | 
6 | class GILRLayer(nn.Module):
7 | 
8 |     def __init__(
9 |         self,
10 |         d_model,
11 |         factor=1,
12 |         dropout=0.2,
13 |         fuse_forget_gate=True
14 |     ):
15 |         super().__init__()
16 |         self.d_model = d_model
17 |         self.fuse_forget_gate = fuse_forget_gate
18 | 
19 | 
20 |         self.in_proj = nn.Linear(self.d_model, self.d_model*3*factor)
21 |         self.dropout = nn.Dropout(dropout)
22 |         self.layer_norm = nn.LayerNorm(factor * self.d_model)
23 |         self.out_proj = nn.Linear(self.d_model * factor, self.d_model)
24 |         self.swish = nn.SiLU()
25 | 
26 |     def forward(self, x):
27 |         u = self.in_proj(x)
28 |         v, o, f = u.chunk(3, dim=-1)
29 | 
30 |         if not self.fuse_forget_gate:
31 |             f = f.sigmoid()
32 |             v = real_scan_tie_input_gate(v.contiguous(), f.contiguous())  # h_t = f_t * h_{t-1} + (1 - f_t) * v_t
33 |         else:
34 |             v = real_scan_tie_input_gate_fused(v.contiguous(), f.contiguous())  # same scan, sigmoid of f fused into the kernel
35 | 
36 |         return self.out_proj(
37 |             self.layer_norm(
38 |                 self.dropout(v * self.swish(o))
39 |             )
40 |         )
41 | 
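42 | 
43 | if __name__ == "__main__":
44 |     # Minimal usage sketch (added for illustration, not part of the original
45 |     # file): assumes a CUDA device with Triton installed. The scan kernels
46 |     # require the scanned channel dimension (d_model * factor) to be a
47 |     # multiple of 256, so d_model=512 works with the default factor=1.
48 |     layer = GILRLayer(d_model=512).cuda()
49 |     x = torch.randn(2, 128, 512, device="cuda")  # (batch, length, d_model)
50 |     y = layer(x)
51 |     print(y.shape)  # expected: torch.Size([2, 128, 512])
52 | 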
--------------------------------------------------------------------------------
/tests/test_real_rnn.py:
--------------------------------------------------------------------------------
1 | from linear_rnn.scan_triton import real_scan_tie_input_gate_fused
2 | import torch
3 | 
4 | def naive_forward_fused( v,
5 |                          f
6 |     ):
7 |     B, L, C = v.shape
8 |     h = v.new_zeros(B, C)
9 | 
10 |     output = v.new_zeros(B, L, C)
11 | 
12 |     for i in range(L):
13 |         input = v[:, i, :]
14 |         decay = f[:, i, :].sigmoid()
15 |         h = (h - input) * decay + input  # i.e. h = decay * h + (1 - decay) * input
16 |         output[:, i, :] = h
17 | 
18 |     return output
19 | 
20 | 
21 | def check_gradient():
22 |     B = 4
23 |     L = 1024
24 |     C = 512
25 |     v, f = torch.rand(B, L, C * 2).chunk(2, dim=-1)
26 | 
27 |     v = v.cuda().requires_grad_(True).contiguous()
28 |     f = f.cuda().requires_grad_(True).contiguous()
29 | 
30 |     grad_output = torch.rand(B, L, C).cuda()
31 | 
32 |     output = naive_forward_fused(v, f)
33 | 
34 |     output.backward(grad_output)
35 | 
36 |     v_grad_clone = v.grad.clone()
37 |     f_grad_clone = f.grad.clone()
38 | 
39 |     v.grad.zero_()
40 |     f.grad.zero_()
41 | 
42 |     output2 = real_scan_tie_input_gate_fused(v, f)
43 |     output2.backward(grad_output)
44 | 
45 |     diff0 = (output2 - output).abs().max()
46 | 
47 |     diff1 = (v.grad - v_grad_clone).abs().max()
48 |     diff2 = (f.grad - f_grad_clone).abs().max()
49 |     print(diff0, diff1, diff2)
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     check_gradient()
54 | 
--------------------------------------------------------------------------------
/linear_rnn/layers/lru.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from linear_rnn.scan_triton import complex_scan
5 | 
6 | 
7 | 
8 | class LRULayer(nn.Module):
9 | 
10 |     def __init__(
11 |         self,
12 |         d_model,
13 |         dropout=0.2
14 |     ):
15 |         super().__init__()
16 |         self.d_model = d_model
17 |         self.in_proj = nn.Linear(self.d_model, self.d_model*4)
18 |         self.dropout = nn.Dropout(dropout)
19 |         self.layer_norm = nn.LayerNorm(2*self.d_model)
20 |         self.out_proj = nn.Linear(2*self.d_model, self.d_model)
21 | 
22 |         nu_log, theta_log, gamma_log = self.initializer()
23 |         self.nu_log = nn.Parameter(nu_log, requires_grad=True)
24 |         self.theta_log = nn.Parameter(theta_log, requires_grad=True)
25 |         self.gamma_log = nn.Parameter(gamma_log, requires_grad=True)
26 | 
27 |         self.swish = nn.SiLU()
28 | 
29 |     def initializer(self):
30 |         # https://arxiv.org/pdf/2303.06349.pdf Sect. 3.2.2
31 |         r_min, r_max = 0.9, 0.999
32 |         u1 = np.random.random(self.d_model)
33 |         u2 = np.random.random(self.d_model)
34 |         nu_log = np.log(
35 |             -0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2)
36 |         )
37 |         theta_log = np.log(u2 * np.pi * 2)
38 |         gamma_log = np.log(np.sqrt(1 - np.exp(-np.exp(nu_log))**2))
39 | 
40 |         return torch.Tensor(nu_log), torch.Tensor(theta_log), torch.Tensor(gamma_log)
41 | 
42 |     def forward(self, x):
43 |         u = self.in_proj(x)
44 |         v, o = u.chunk(2, dim=-1)
45 | 
46 |         nu = torch.exp(-torch.exp(self.nu_log))  # |lambda|: magnitude of the diagonal recurrence
47 |         theta = torch.exp(self.theta_log)        # phase of the diagonal recurrence
48 |         gamma = torch.exp(self.gamma_log)        # input normalization, sqrt(1 - |lambda|^2) at init
49 | 
50 |         f_real = nu * torch.cos(theta)
51 |         f_imag = nu * torch.sin(theta)
52 | 
53 |         input_real, input_imag = v.chunk(2, dim=-1)
54 |         input_real = gamma[None, None, :] * input_real
55 |         input_imag = gamma[None, None, :] * input_imag
56 | 
57 |         f_real = f_real[None, None, :].expand_as(input_real)
58 |         f_imag = f_imag[None, None, :].expand_as(input_real)
59 | 
60 |         output_real, output_imag = complex_scan(
61 |             input_real.contiguous(), input_imag.contiguous(),
62 |             f_real.contiguous(), f_imag.contiguous()
63 |         )
64 | 
65 |         return self.out_proj(
66 |             self.layer_norm(
67 |                 self.dropout(
68 |                     torch.cat([output_real, output_imag], dim=-1) * self.swish(o)
69 |                 )
70 |             )
71 |         )
72 | 
73 | 
74 | 
--------------------------------------------------------------------------------
/tests/test_complex_rnn.py:
--------------------------------------------------------------------------------
1 | from linear_rnn.scan_triton import complex_scan
2 | import torch
3 | 
4 | def naive_forward( v_real,
5 |                    v_imag,
6 |                    f_real,
7 |                    f_imag
8 |     ):
9 |     B, L, C = v_real.shape
10 |     h_real = v_real.new_zeros(B, C)
11 |     h_imag = v_real.new_zeros(B, C)
12 | 
13 |     output_real = v_real.new_zeros(B, L, C)
14 |     output_imag = v_real.new_zeros(B, L, C)
15 | 
16 |     for i in range(L):
17 |         input_real = v_real[:, i, :]
18 |         input_imag = v_imag[:, i, :]
19 |         decay_real = f_real[:, i, :]
20 |         decay_imag = f_imag[:, i, :]
21 | 
22 |         h_real_new = (h_real * decay_real - h_imag * decay_imag) + input_real
23 | 
24 |         h_imag_new = (h_real * decay_imag + h_imag * decay_real) + input_imag
25 | 
26 |         output_real[:, i, :] = h_real_new
27 |         output_imag[:, i, :] = h_imag_new
28 | 
29 |         h_real = h_real_new
30 |         h_imag = h_imag_new
31 | 
32 |     return output_real, output_imag
33 | 
34 | 
35 | def check_gradient():
36 |     B = 4
37 |     L = 1024
38 |     C = 512
39 |     v_real, v_imag, f_real, f_imag = torch.rand(B, L, C * 4).chunk(4, dim=-1)
40 | 
41 |     v_real = v_real.cuda().requires_grad_(True).contiguous()
42 |     v_imag = v_imag.cuda().requires_grad_(True).contiguous()
43 |     f_real = f_real.cuda().requires_grad_(True).contiguous()
44 |     f_imag = f_imag.cuda().requires_grad_(True).contiguous()
45 | 
46 |     grad_output_real = torch.rand(B, L, C).cuda()
47 |     grad_output_imag = torch.rand(B, L, C).cuda()
48 | 
49 |     output1, output2 = naive_forward(v_real, v_imag, f_real, f_imag)
50 | 
51 |     (output1 * grad_output_real + output2 * grad_output_imag).sum().backward()
52 | 
53 |     v_real_grad_clone = v_real.grad.clone()
54 |     v_imag_grad_clone = v_imag.grad.clone()
55 |     f_real_grad_clone = f_real.grad.clone()
56 |     f_imag_grad_clone = f_imag.grad.clone()
57 | 
58 |     v_real.grad.zero_()
59 |     v_imag.grad.zero_()
60 |     f_real.grad.zero_()
61 |     f_imag.grad.zero_()
62 | 
63 |     output3, output4 = complex_scan(v_real, v_imag, f_real, f_imag)
64 | 
65 |     (output3 * grad_output_real + output4 * grad_output_imag).sum().backward()
66 | 
67 |     diff0 = (output3 - output1).abs().max()
68 |     diff00 = (output4 - output2).abs().max()
69 | 
70 |     diff1 = (v_real.grad - v_real_grad_clone).abs().max()
71 |     diff2 = (v_imag.grad - v_imag_grad_clone).abs().max()
72 |     diff3 = (f_real.grad - f_real_grad_clone).abs().max()
73 |     diff4 = (f_imag.grad - f_imag_grad_clone).abs().max()
74 |     print(diff0, diff00, diff1, diff2, diff3, diff4)
75 | 
76 | 
77 | if __name__ == "__main__":
78 |     check_gradient()
--------------------------------------------------------------------------------
/linear_rnn/scan_triton/complex_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import triton
4 | import triton.language as tl
5 | from torch.autograd import Function
6 | 
7 | 
8 | @triton.jit
9 | def fwd_sequential_scan_complex(
10 |     v_real,
11 |     v_imag,
12 |     decay_real,
13 |     decay_imag,
14 |     hidden_real,
15 |     hidden_imag,
16 |     B,
17 |     L,
18 |     C,
19 |     BLOCK_M: tl.constexpr,
20 | ):
21 | 
22 |     offset_b = tl.program_id(0)
23 | 
24 |     if offset_b >= B:
25 |         return
26 | 
27 |     offset_n = tl.program_id(1)
28 |     ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M
29 |     h_real 
= tl.zeros([BLOCK_M,], dtype=tl.float32) 30 | h_imag = tl.zeros([BLOCK_M,], dtype=tl.float32) 31 | 32 | for _ in range(L): 33 | x_real = tl.load(v_real + ptr).to(tl.float32) 34 | x_imag = tl.load(v_imag + ptr).to(tl.float32) 35 | 36 | f_real = tl.load(decay_real + ptr).to(tl.float32) 37 | f_imag = tl.load(decay_imag + ptr).to(tl.float32) 38 | 39 | h_real_new = h_real * f_real - h_imag * f_imag + x_real 40 | h_imag_new = h_real * f_imag + h_imag * f_real + x_imag 41 | 42 | tl.store(hidden_real + ptr, h_real_new.to(hidden_real.dtype.element_ty)) 43 | tl.store(hidden_imag + ptr, h_imag_new.to(hidden_imag.dtype.element_ty)) 44 | h_real = h_real_new 45 | h_imag = h_imag_new 46 | ptr += C 47 | 48 | 49 | @triton.jit 50 | def bwd_sequential_scan_complex( 51 | 52 | grad_output_real, 53 | grad_output_imag, 54 | 55 | v_real, 56 | v_imag, 57 | 58 | f_real, 59 | f_imag, 60 | 61 | hidden_real, 62 | hidden_imag, 63 | 64 | B, 65 | L, 66 | C, 67 | BLOCK_M: tl.constexpr, 68 | ): 69 | 70 | offset_b = tl.program_id(0) 71 | 72 | if offset_b >= B: 73 | return 74 | 75 | offset_n = tl.program_id(1) 76 | 77 | ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M 78 | 79 | grad_h_real = tl.zeros([BLOCK_M,], dtype=tl.float32) 80 | grad_h_imag = tl.zeros([BLOCK_M,], dtype=tl.float32) 81 | 82 | for time_step in range(L-1, -1, -1): 83 | grad_real = tl.load(grad_output_real + ptr).to(tl.float32) 84 | grad_imag = tl.load(grad_output_imag + ptr).to(tl.float32) 85 | 86 | grad_h_real += grad_real 87 | grad_h_imag += grad_imag 88 | 89 | decay_real = tl.load(f_real + ptr).to(tl.float32) 90 | decay_imag = tl.load(f_imag + ptr).to(tl.float32) 91 | 92 | h_real = tl.load(hidden_real + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32) 93 | h_imag = tl.load(hidden_imag + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32) 94 | 95 | grad_f_real = (grad_h_real * h_real + grad_h_imag * h_imag) 96 | grad_f_imag = (grad_h_imag * h_real - grad_h_real * h_imag) 97 | 98 | tl.store(f_real + ptr, grad_f_real.to(f_real.dtype.element_ty)) 99 | tl.store(f_imag + ptr, grad_f_imag.to(f_real.dtype.element_ty)) 100 | 101 | tl.store(v_real + ptr, grad_h_real.to(v_real.dtype.element_ty)) 102 | tl.store(v_imag + ptr, grad_h_imag.to(v_real.dtype.element_ty)) 103 | 104 | grad_h_real_new = grad_h_real * decay_real + grad_h_imag * decay_imag 105 | grad_h_imag_new = grad_h_imag * decay_real - grad_h_real * decay_imag 106 | 107 | grad_h_real = grad_h_real_new 108 | grad_h_imag = grad_h_imag_new 109 | 110 | ptr -= C 111 | 112 | 113 | 114 | class TritonSequentialScan_Complex(Function): 115 | @staticmethod 116 | @torch.cuda.amp.custom_fwd 117 | def forward(ctx, v_real, v_imag, f_real, f_imag): 118 | B,L,C = v_real.shape 119 | num_warps = 8 120 | assert C % 256 == 0, 'Hidden dimension must be multiple of 256' 121 | v_real = v_real.contiguous() 122 | v_imag = v_imag.contiguous() 123 | f_real = f_real.contiguous() 124 | f_imag = f_imag.contiguous() 125 | 126 | hidden_real = torch.zeros_like(v_real).contiguous() 127 | hidden_imag = torch.zeros_like(v_imag).contiguous() 128 | 129 | fwd_sequential_scan_complex[(B, int(C/256))]( 130 | v_real, 131 | v_imag, 132 | f_real, 133 | f_imag, 134 | hidden_real, 135 | hidden_imag, 136 | B, 137 | L, 138 | C, 139 | BLOCK_M=256, 140 | num_warps=num_warps 141 | ) 142 | 143 | ctx.save_for_backward(v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag) 144 | return hidden_real, hidden_imag 145 | 146 | @staticmethod 147 | @torch.cuda.amp.custom_bwd 148 | def 
backward(ctx, grad_output_real, grad_output_imag): 149 | 150 | v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag = ctx.saved_tensors 151 | B, L, C = v_real.shape 152 | 153 | num_warps = 8 154 | 155 | 156 | bwd_sequential_scan_complex[(B, int(C/256))]( 157 | grad_output_real, 158 | grad_output_imag, 159 | 160 | v_real, 161 | v_imag, 162 | f_real, 163 | f_imag, 164 | hidden_real, 165 | hidden_imag, 166 | 167 | B, 168 | L, 169 | C, 170 | BLOCK_M=256, 171 | num_warps=num_warps 172 | ) 173 | return v_real, v_imag, f_real, f_imag 174 | 175 | 176 | complex_scan = TritonSequentialScan_Complex.apply 177 | 178 | 179 | -------------------------------------------------------------------------------- /linear_rnn/scan_triton/real_rnn_tie_input_gate.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | # input: (B, L, D) 8 | @triton.jit 9 | def fwd_sequential_scan( 10 | v, 11 | f1, 12 | hidden, 13 | B, 14 | L, 15 | C, 16 | BLOCK_M: tl.constexpr, 17 | ): 18 | 19 | offset_b = tl.program_id(0) 20 | 21 | if offset_b >= B: 22 | return 23 | 24 | offset_n = tl.program_id(1) 25 | ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M 26 | h1 = tl.zeros([BLOCK_M,], dtype=tl.float32) 27 | 28 | for _ in range(L): 29 | x0 = tl.load(v + ptr).to(tl.float32) 30 | decay1 = tl.load(f1 + ptr).to(tl.float32) 31 | h1 = (h1 - x0) * decay1 + x0 32 | tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty) ) 33 | ptr += C 34 | 35 | 36 | @triton.jit 37 | def fwd_sequential_scan_fused( 38 | v, 39 | f1, 40 | hidden, 41 | B, 42 | L, 43 | C, 44 | BLOCK_M: tl.constexpr, 45 | ): 46 | 47 | offset_b = tl.program_id(0) 48 | 49 | if offset_b >= B: 50 | return 51 | 52 | offset_n = tl.program_id(1) 53 | ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M 54 | h1 = tl.zeros([BLOCK_M,], dtype=tl.float32) 55 | 56 | for _ in range(L): 57 | x0 = tl.load(v + ptr).to(tl.float32) 58 | decay1 = tl.load(f1 + ptr).to(tl.float32) 59 | decay1 = tl.sigmoid(decay1) 60 | h1 = (h1 - x0) * decay1 + x0 61 | tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty) ) 62 | ptr += C 63 | 64 | 65 | # input: (B, L, D) 66 | @triton.jit 67 | def bwd_sequential_scan( 68 | grad_output, 69 | 70 | v, 71 | f, 72 | 73 | h, 74 | 75 | B, 76 | L, 77 | C, 78 | BLOCK_M: tl.constexpr, 79 | ): 80 | 81 | 82 | offset_b = tl.program_id(0) 83 | 84 | if offset_b >= B: 85 | return 86 | 87 | offset_n = tl.program_id(1) 88 | 89 | ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M 90 | 91 | grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32) 92 | 93 | for time_step in range(L-1, -1, -1): 94 | 95 | grad = tl.load(grad_output + ptr).to(tl.float32) 96 | 97 | grad_h += grad 98 | 99 | decay = tl.load(f + ptr).to(tl.float32) 100 | input = tl.load(v + ptr).to(tl.float32) 101 | 102 | grad_v = (1 - decay) * grad_h 103 | tl.store(v + ptr, grad_v.to(v.dtype.element_ty)) 104 | 105 | hidden_state = tl.load(h + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32) 106 | 107 | grad_f = grad_h * (hidden_state - input) 108 | 109 | tl.store(f + ptr, grad_f.to(f.dtype.element_ty)) 110 | 111 | grad_h *= decay 112 | 113 | 114 | 115 | ptr -= C 116 | 117 | # input: (B, L, D) 118 | @triton.jit 119 | def bwd_sequential_scan_fused( 120 | grad_output, 121 | 122 | v, 123 | f, 124 | 125 | h, 126 | 127 | B, 128 | L, 129 | C, 130 | BLOCK_M: tl.constexpr, 131 | ): 132 
| 133 | offset_b = tl.program_id(0) 134 | 135 | if offset_b >= B: 136 | return 137 | 138 | offset_n = tl.program_id(1) 139 | 140 | ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M 141 | 142 | grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32) 143 | 144 | for time_step in range(L-1, -1, -1): 145 | 146 | grad = tl.load(grad_output + ptr).to(tl.float32) 147 | 148 | grad_h += grad 149 | 150 | decay = tl.load(f + ptr).to(tl.float32) 151 | decay = tl.sigmoid(decay) 152 | input = tl.load(v + ptr).to(tl.float32) 153 | 154 | grad_v = (1 - decay) * grad_h 155 | tl.store(v + ptr, grad_v.to(v.dtype.element_ty)) 156 | 157 | hidden_state = tl.load(h + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32) 158 | 159 | grad_f = grad_h * (hidden_state - input) * decay * (1 - decay) 160 | 161 | tl.store(f + ptr, grad_f.to(f.dtype.element_ty)) 162 | 163 | grad_h *= decay 164 | 165 | 166 | ptr -= C 167 | 168 | 169 | class TritonSequentialScan(Function): 170 | @staticmethod 171 | @torch.cuda.amp.custom_fwd 172 | def forward(ctx, v, f1): 173 | B,L,C = v.shape 174 | num_warps = 8 175 | assert C % 256 == 0 176 | v = v.contiguous() 177 | f1 = f1.contiguous() 178 | hidden = torch.zeros_like(v).contiguous() 179 | 180 | fwd_sequential_scan[(B, int(C/256) )]( 181 | v, 182 | f1, 183 | hidden, 184 | B, 185 | L, 186 | C, 187 | BLOCK_M=256, 188 | num_warps=num_warps 189 | ) 190 | 191 | ctx.save_for_backward(v, f1, hidden) 192 | return hidden 193 | 194 | @staticmethod 195 | @torch.cuda.amp.custom_bwd 196 | def backward(ctx, grad_output): 197 | v, f1, hidden = ctx.saved_tensors 198 | B, L, C = v.shape 199 | 200 | num_warps = 8 201 | 202 | bwd_sequential_scan[(B, int(C/256))]( 203 | grad_output, 204 | v, 205 | f1, 206 | hidden, 207 | B, 208 | L, 209 | C, 210 | BLOCK_M=256, 211 | num_warps=num_warps 212 | ) 213 | return v, f1 214 | 215 | 216 | class TritonSequentialScanFused(Function): 217 | @staticmethod 218 | @torch.cuda.amp.custom_fwd 219 | def forward(ctx, v, f1): 220 | B,L,C = v.shape 221 | num_warps = 8 222 | assert C % 256 == 0 223 | v = v.contiguous() 224 | f1 = f1.contiguous() 225 | hidden = torch.zeros_like(v).contiguous() 226 | 227 | fwd_sequential_scan_fused[(B, int(C/256) )]( 228 | v, 229 | f1, 230 | hidden, 231 | B, 232 | L, 233 | C, 234 | BLOCK_M=256, 235 | num_warps=num_warps 236 | ) 237 | 238 | ctx.save_for_backward(v, f1, hidden) 239 | return hidden 240 | 241 | @staticmethod 242 | @torch.cuda.amp.custom_bwd 243 | def backward(ctx, grad_output): 244 | v, f1, hidden = ctx.saved_tensors 245 | B, L, C = v.shape 246 | 247 | num_warps = 8 248 | 249 | bwd_sequential_scan_fused[(B, int(C/256))]( 250 | grad_output, 251 | v, 252 | f1, 253 | hidden, 254 | B, 255 | L, 256 | C, 257 | BLOCK_M=256, 258 | num_warps=num_warps 259 | ) 260 | return v, f1 261 | 262 | 263 | real_scan_tie_input_gate = TritonSequentialScan.apply 264 | 265 | 266 | real_scan_tie_input_gate_fused = TritonSequentialScanFused.apply 267 | 268 | 269 | 270 | 271 | --------------------------------------------------------------------------------
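Below is a minimal usage sketch (added for illustration; it is not part of the repository). It assumes a CUDA device with Triton installed; both scan wrappers assert that the scanned channel dimension is a multiple of 256, so `d_model=512` is used here.

```python
# Illustrative only: LRULayer smoke test on random data.
import torch
from linear_rnn.layers import LRULayer

layer = LRULayer(d_model=512).cuda()
x = torch.randn(2, 128, 512, device="cuda")  # (batch, length, d_model)
print(layer(x).shape)  # torch.Size([2, 128, 512])
```

The gradient checks under `tests/` compare the Triton kernels against the naive Python recurrences and can be run from the repository root, e.g. `PYTHONPATH=. python tests/test_complex_rnn.py`.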