├── README.md
├── test_gradient.py
├── benchmark.py
├── .gitignore
├── ref.py
├── p_scan.py
├── selective_scan_interface.py
├── triton_sequential_scan.py
└── triton_parallel_scan.py

/README.md:
--------------------------------------------------------------------------------
1 | - Strange bug in ```triton_parallel_scan.py```: it behaves differently for different input lengths. Be really careful with ```tl.cumsum```; it has lots of bugs. See issue [#3017](https://github.com/openai/triton/issues/3017).
2 | 
3 | - ```triton_sequential_scan.py``` uses only a for loop, but it is faster than ```triton_parallel_scan.py``` and only around 2.5 times slower than Tri Dao's CUDA implementation. It also supports an initial state.
--------------------------------------------------------------------------------
/test_gradient.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | from ref import ref_selective_scan
4 | from triton_parallel_scan import triton_selective_scan
5 | from triton_sequential_scan import triton_selective_scan_sequential
6 | 
7 | if __name__ == '__main__':
8 |     B = 2
9 |     T = 16
10 |     D = 512
11 |     K = 16
12 |     dtype = torch.float32
13 |     A = (-(torch.rand(D, K, dtype=dtype)).exp().cuda()).requires_grad_(True)
14 |     x = torch.randn(B, T, D, dtype=dtype).cuda().requires_grad_(True)
15 |     delta = torch.randn(B, T, D, dtype=dtype).sigmoid().cuda().requires_grad_(True)
16 |     B2 = torch.randn(B, T, K, dtype=dtype).cuda().requires_grad_(True)
17 |     C = torch.randn(B, T, K, dtype=dtype).cuda().requires_grad_(True)
18 |     D2 = torch.randn(D, dtype=dtype).cuda().requires_grad_(True)
19 | 
20 |     initial_state = torch.randn(B, D, K, dtype=dtype).cuda().requires_grad_(False)
21 | 
22 |     tri, tri_final = triton_selective_scan_sequential(x, delta, A, B2, C, D2, initial_state)
23 |     do = torch.randn_like(tri)
24 |     tri.backward(do)
25 | 
26 |     tri_dc, C.grad = C.grad.clone(), None
27 |     tri_dx, x.grad = x.grad.clone(), None
28 |     tri_db, B2.grad = B2.grad.clone(), None
29 |     tri_delta, delta.grad = delta.grad.clone(), None
30 |     tri_A, A.grad = A.grad.clone(), None
31 | 
32 |     ref, ref_final = ref_selective_scan(x, delta, A, B2, C, D2, initial_state)
33 | 
34 |     print((tri-ref).abs().max())
35 |     print((tri_final-ref_final).abs().max())
36 | 
37 |     ref.backward(do)
38 |     ref_dc, C.grad = C.grad.clone(), None
39 |     ref_dx, x.grad = x.grad.clone(), None
40 |     ref_db, B2.grad = B2.grad.clone(), None
41 |     ref_delta, delta.grad = delta.grad.clone(), None
42 |     ref_A, A.grad = A.grad.clone(), None
43 | 
44 |     print((tri_dc-ref_dc).abs().max())
45 |     print((tri_dx-ref_dx).abs().max())
46 |     print((tri_db-ref_db).abs().max())
47 |     print((tri_delta-ref_delta).abs().max())
48 |     print((tri_A-ref_A).abs().max())
49 |     breakpoint()
50 | 
51 | 
52 | 
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | 
2 | import time
3 | 
4 | import torch
5 | from p_scan import pscan_selective_scan
6 | from ref import ref_selective_scan
7 | from selective_scan_interface import selective_scan_fn, selective_scan_ref
8 | from triton_parallel_scan import triton_selective_scan
9 | from triton_sequential_scan import triton_selective_scan_sequential
10 | 
11 | if __name__ == '__main__':
12 |     B = 32
13 |     T = 2048
14 |     D = 1024
15 |     K = 16
16 |     dtype = torch.bfloat16
17 |     A = (-(torch.rand(D, K, dtype=torch.float32)).exp().cuda()).requires_grad_(True)
18 |     x = torch.randn(B, T, D, dtype=dtype).cuda().requires_grad_(True)
19 |     delta = torch.randn(B, T, D, dtype=dtype).sigmoid().cuda().requires_grad_(True)
20 |     B2 = torch.randn(B, T, K, dtype=dtype).cuda().requires_grad_(True)
21 |     C = torch.randn(B, T, K, dtype=dtype).cuda().requires_grad_(True)
22 |     D2 = torch.randn(D, dtype=torch.float32).cuda().requires_grad_(True)
23 | 
24 | 
25 | 
26 |     do = torch.randn_like(x)
27 |     print("Warmup start")
28 |     for _ in range(50):
29 |         o, final_state = triton_selective_scan_sequential(x, delta, A, B2, C, D2)
30 |         o.backward(do, retain_graph=True)
31 |         o = selective_scan_fn(x, delta, A, B2, C, D2)
32 |         o.backward(do, retain_graph=True)
33 | 
34 |     print("Warmup done")
35 |     torch.cuda.synchronize()  # synchronize before reading the clock so warmup kernels are not timed
36 |     start = time.time()
37 |     # with torch.no_grad():
38 |     for _ in range(100):
39 |         o, final_state = triton_selective_scan_sequential(x, delta, A, B2, C, D2)
40 |         o.backward(do, retain_graph=True)
41 |     torch.cuda.synchronize()
42 |     end = time.time()
43 |     print('triton', end - start)
44 | 
45 |     torch.cuda.synchronize()
46 |     start = time.time()
47 |     # with torch.no_grad():
48 |     for _ in range(100):
49 |         o = selective_scan_fn(x, delta, A, B2, C, D2)
50 |         o.backward(do, retain_graph=True)
51 |     torch.cuda.synchronize()
52 |     end = time.time()
53 |     print('cuda', end - start)
54 | 
55 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data.tar.gz
2 | *.tsf
3 | *.ckpt
4 | .ipynb_checkpoints
5 | */.ipynb_checkpoints/*
6 | *.lprof
7 | 
8 | .DS_Store
9 | .idea/
10 | outputs/
11 | 
12 | data
13 | 
14 | # Created by https://www.gitignore.io/api/python
15 | # Edit at https://www.gitignore.io/?templates=python
16 | 
17 | ### Python ###
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | 
23 | # C extensions
24 | *.so
25 | 
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 | 
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 | 
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 | 
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | .hypothesis/
68 | .pytest_cache/
69 | 
70 | # Translations
71 | *.mo
72 | *.pot
73 | 
74 | # Scrapy stuff:
75 | .scrapy
76 | 
77 | # Sphinx documentation
78 | docs/_build/
79 | 
80 | # PyBuilder
81 | target/
82 | 
83 | # pyenv
84 | .python-version
85 | 
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 | 
93 | # celery beat schedule file
94 | celerybeat-schedule
95 | 
96 | # SageMath parsed files
97 | *.sage.py
98 | 
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 | 
103 | # Rope project settings
104 | .ropeproject
105 | 
106 | # Mr Developer
107 | .mr.developer.cfg
108 | .project
109 | .pydevproject
110 | 
111 | # mkdocs documentation
112 | /site
113 | 
114 | # mypy
115 | .mypy_cache/
116 | .dmypy.json
117 | dmypy.json
118 | 
119 | # Pyre type checker
120 | .pyre/
121 | 
122 | # End of https://www.gitignore.io/api/python
123 | 
--------------------------------------------------------------------------------
/ref.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import einsum, rearrange, repeat
3 | 
4 | 
5 | # credit: https://github.com/johnma2006/mamba-minimal/blob/master/model.py#L275
6 | def ref_selective_scan(u, delta, A, B, C, D, initial_state):
7 |     """Does the selective scan algorithm. See:
8 |         - Section 2 State Space Models in the Mamba paper [1]
9 |         - Algorithm 2 in Section 3.2 in the Mamba paper [1]
10 |         - run_SSM(A, B, C, u) in The Annotated S4 [2]
11 | 
12 |     This is the classic discrete state space formula:
13 |         x(t + 1) = Ax(t) + Bu(t)
14 |         y(t)     = Cx(t) + Du(t)
15 |     except B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).
16 | 
17 |     Args:
18 |         u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
19 |         delta: shape (b, l, d_in)
20 |         A: shape (d_in, n)
21 |         B: shape (b, l, n)
22 |         C: shape (b, l, n)
23 |         D: shape (d_in,)
24 |         initial_state: shape (b, d_in, n)
25 |     Returns:
26 |         output: shape (b, l, d_in), together with the final state of shape (b, d_in, n)
27 | 
28 |     Official Implementation:
29 |         selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
30 |         Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.
31 | 
32 |     """
33 |     original_dtype = u.dtype
34 |     u, delta, A, B, C, D = map(lambda x: x.float(), (u, delta, A, B, C, D))
35 |     (b, l, d_in) = u.shape
36 |     n = A.shape[1]
37 | 
38 |     # Discretize continuous parameters (A, B)
39 |     # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
40 |     # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
41 |     #   "A is the more important term and the performance doesn't change much with the simplification on B"
42 |     deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
43 |     deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
44 | 
45 |     # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
46 |     # Note that the below is sequential, while the official implementation does a much faster parallel scan that
47 |     # is additionally hardware-aware (like FlashAttention).
48 |     x = torch.zeros((b, d_in, n), device=deltaA.device)
49 |     x += initial_state  # seed the scan with the provided initial state
50 |     ys = []
51 |     for i in range(l):
52 |         x = deltaA[:, i] * x + deltaB_u[:, i]
53 |         y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
54 |         ys.append(y)
55 |     y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
56 | 
57 |     y = y + u * D[None, None, :]
58 | 
59 |     return y.to(original_dtype), x
--------------------------------------------------------------------------------
/p_scan.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | # credit: https://github.com/alxndrTL/mamba.py/blob/main/pscan.py
5 | from einops import einsum, rearrange, repeat
6 | 
7 | """
8 | 
9 | An implementation of the parallel scan operation in PyTorch (Blelloch version).
10 | This code follows the skeleton proposed by Francois Fleuret in his pscan. However, the key differences are:
11 | - it is written iteratively (rather than recursively)
12 | - the backward pass has been rewritten
13 | 
14 | Please see docs/pscan.ipynb for a detailed explanation of what happens here.
15 | 
16 | """
17 | 
18 | # TODO: avoid the .flip() calls by implementing a reverse pscan (with a flag)
19 | 
20 | class PScan(torch.autograd.Function):
21 |     @staticmethod
22 |     def pscan(A, X):
23 |         # A : (B, D, L, N)
24 |         # X : (B, D, L, N)
25 | 
26 |         # modifies X in place by doing a parallel scan.
27 |         # more formally, X will be populated by these values :
28 |         # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
29 |         # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
30 | 
31 |         B, D, L, _ = A.size()
32 |         num_steps = int(math.log2(L))
33 | 
34 |         # up sweep or reduction step
35 |         Aa = A
36 |         Xa = X
37 |         for k in range(num_steps):
38 |             T = 2 * (Xa.size(2) // 2)
39 | 
40 |             Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
41 |             Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
42 | 
43 |             Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
44 |             Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
45 | 
46 |             Aa = Aa[:, :, :, 1]
47 |             Xa = Xa[:, :, :, 1]
48 | 
49 |         # down sweep
50 |         for k in range(num_steps-1, -1, -1):
51 |             Aa = A[:, :, 2**k-1:L:2**k]
52 |             Xa = X[:, :, 2**k-1:L:2**k]
53 | 
54 |             T = 2 * (Xa.size(2) // 2)
55 | 
56 |             if T < Xa.size(2):
57 |                 Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
58 |                 Aa[:, :, -1].mul_(Aa[:, :, -2])
59 | 
60 |             Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
61 |             Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
62 | 
63 |             Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
64 |             Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
65 | 
66 |     @staticmethod
67 |     def forward(ctx, A_in, X_in):
68 |         """
69 |         Applies the parallel scan operation, as defined above. Returns a new tensor.
70 | 
71 |         Args:
72 |             A_in : (B, L, D, N)
73 |             X_in : (B, L, D, N)
74 | 
75 |         Returns:
76 |             H : (B, L, D, N)
77 |         """
78 | 
79 |         # clone tensor (in-place ops)
80 |         A = A_in.clone()  # (B, L, D, N)
81 |         X = X_in.clone()  # (B, L, D, N)
82 | 
83 |         # prepare tensors
84 |         A = A.transpose(2, 1)  # (B, D, L, N)
85 |         X = X.transpose(2, 1)  # (B, D, L, N)
86 | 
87 |         # parallel scan
88 |         PScan.pscan(A, X)
89 | 
90 |         ctx.save_for_backward(A_in, X)
91 | 
92 |         return X.transpose(2, 1)
93 | 
94 |     @staticmethod
95 |     def backward(ctx, grad_output_in):
96 |         """
97 |         Flows the gradient from the output to the input. Returns two new tensors. The gradient of a first-order scan is itself a first-order scan run in the reverse direction, which is why the same pscan kernel is reused below on time-flipped tensors.
98 | 
99 |         Args:
100 |             ctx : A_in : (B, L, D, N), X : (B, D, L, N)
101 |             grad_output_in : (B, L, D, N)
102 | 
103 |         Returns:
104 |             gradA : (B, L, D, N), gradX : (B, L, D, N)
105 |         """
106 | 
107 |         A_in, X = ctx.saved_tensors
108 | 
109 |         # clone tensors
110 |         A = A_in.clone()
111 |         # grad_output_in will be cloned with flip()
112 | 
113 |         # prepare tensors
114 |         A = A.transpose(2, 1)  # (B, D, L, N)
115 |         A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2)
116 |         grad_output_b = grad_output_in.transpose(2, 1)
117 | 
118 |         # reverse parallel scan
119 |         grad_output_b = grad_output_b.flip(2)
120 |         PScan.pscan(A, grad_output_b)
121 |         grad_output_b = grad_output_b.flip(2)
122 | 
123 |         Q = torch.zeros_like(X)
124 |         Q[:, :, 1:].add_(X[:, :, :-1] * grad_output_b[:, :, 1:])
125 | 
126 |         return Q.transpose(2, 1), grad_output_b.transpose(2, 1)
127 | 
128 | pscan = PScan.apply
129 | 
130 | 
131 | def pscan_selective_scan(u, delta, A, B, C, D):
132 |     """Does the selective scan algorithm. See:
133 |         - Section 2 State Space Models in the Mamba paper [1]
134 |         - Algorithm 2 in Section 3.2 in the Mamba paper [1]
135 |         - run_SSM(A, B, C, u) in The Annotated S4 [2]
136 | 
137 |     This is the classic discrete state space formula:
138 |         x(t + 1) = Ax(t) + Bu(t)
139 |         y(t)     = Cx(t) + Du(t)
140 |     except B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).
141 | 
142 |     Args:
143 |         u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
144 |         delta: shape (b, l, d_in)
145 |         A: shape (d_in, n)
146 |         B: shape (b, l, n)
147 |         C: shape (b, l, n)
148 |         D: shape (d_in,)
149 | 
150 |     Returns:
151 |         output: shape (b, l, d_in)
152 | 
153 |     Official Implementation:
154 |         selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
155 |         Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.
156 | 
157 |     """
158 |     original_dtype = u.dtype
159 |     # u, delta, A, B, C, D = map(lambda x: x.float(), (u, delta, A, B, C, D))
160 |     (b, l, d_in) = u.shape
161 |     n = A.shape[1]
162 | 
163 |     # Discretize continuous parameters (A, B)
164 |     # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
165 |     # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
166 |     #   "A is the more important term and the performance doesn't change much with the simplification on B"
167 |     deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
168 |     deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
169 | 
170 |     hs = pscan(deltaA, deltaB_u)
171 |     y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
172 |     y = y + D * u
173 | 
174 |     return y.to(original_dtype)
--------------------------------------------------------------------------------
/selective_scan_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 | # https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py 3 | # import causal_conv1d_cuda 4 | import selective_scan_cuda 5 | import torch 6 | import torch.nn.functional as F 7 | # from causal_conv1d import causal_conv1d_fn 8 | from einops import rearrange, repeat 9 | from torch.cuda.amp import custom_bwd, custom_fwd 10 | 11 | 12 | class SelectiveScanFn(torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 16 | return_last_state=False): 17 | if u.stride(-1) != 1: 18 | u = u.contiguous() 19 | if delta.stride(-1) != 1: 20 | delta = delta.contiguous() 21 | if D is not None: 22 | D = D.contiguous() 23 | if B.stride(-1) != 1: 24 | B = B.contiguous() 25 | if C.stride(-1) != 1: 26 | C = C.contiguous() 27 | if z is not None and z.stride(-1) != 1: 28 | z = z.contiguous() 29 | if B.dim() == 3: 30 | B = rearrange(B, "b dstate l -> b 1 dstate l") 31 | ctx.squeeze_B = True 32 | if C.dim() == 3: 33 | C = rearrange(C, "b dstate l -> b 1 dstate l") 34 | ctx.squeeze_C = True 35 | 36 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) 37 | ctx.delta_softplus = delta_softplus 38 | ctx.has_z = z is not None 39 | last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) 40 | if not ctx.has_z: 41 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 42 | return out if not return_last_state else (out, last_state) 43 | else: 44 | ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) 45 | out_z = rest[0] 46 | return out_z if not return_last_state else (out_z, last_state) 47 | 48 | @staticmethod 49 | def backward(ctx, dout, *args): 50 | if not ctx.has_z: 51 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 52 | z = None 53 | out = None 54 | else: 55 | u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors 56 | if dout.stride(-1) != 1: 57 | dout = dout.contiguous() 58 | # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the 59 | # backward of selective_scan_cuda with the backward of chunk). 60 | # Here we just pass in None and dz will be allocated in the C++ code. 61 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 62 | u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, 63 | False # option to recompute out_z, not used here 64 | ) 65 | dz = rest[0] if ctx.has_z else None 66 | dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB 67 | dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC 68 | return (du, ddelta, dA, dB, dC, 69 | dD if D is not None else None, 70 | dz, 71 | ddelta_bias if delta_bias is not None else None, 72 | None, 73 | None) 74 | 75 | 76 | def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 77 | return_last_state=False): 78 | """if return_last_state is True, returns (out, last_state) 79 | last_state has shape (batch, dim, dstate). Note that the gradient of the last state is 80 | not considered in the backward pass. 
81 |     """
82 |     u = u.transpose(-1, -2).contiguous()  # inputs arrive as (B, L, ...); the CUDA kernel expects (B, ..., L)
83 |     B = B.transpose(-1, -2).contiguous()
84 |     C = C.transpose(-1, -2).contiguous()
85 |     delta = delta.transpose(-1, -2).contiguous()
86 |     o = SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
87 |     return (o[0].transpose(-1, -2).contiguous(), o[1]) if return_last_state else o.transpose(-1, -2).contiguous()
88 | 
89 | 
90 | 
91 | def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
92 |                        return_last_state=False):
93 |     """
94 |     u: r(B D L)
95 |     delta: r(B D L)
96 |     A: c(D N) or r(D N)
97 |     B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
98 |     C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
99 |     D: r(D)
100 |     z: r(B D L)
101 |     delta_bias: r(D), fp32
102 | 
103 |     out: r(B D L)
104 |     last_state (optional): r(B D dstate) or c(B D dstate)
105 |     """
106 |     dtype_in = u.dtype
107 |     u = u.float()
108 |     delta = delta.float()
109 |     if delta_bias is not None:
110 |         delta = delta + delta_bias[..., None].float()
111 |     if delta_softplus:
112 |         delta = F.softplus(delta)
113 |     batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
114 |     is_variable_B = B.dim() >= 3
115 |     is_variable_C = C.dim() >= 3
116 |     if A.is_complex():
117 |         if is_variable_B:
118 |             B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
119 |         if is_variable_C:
120 |             C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
121 |     else:
122 |         B = B.float()
123 |         C = C.float()
124 |     x = A.new_zeros((batch, dim, dstate))
125 |     ys = []
126 |     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
127 |     if not is_variable_B:
128 |         deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
129 |     else:
130 |         if B.dim() == 3:
131 |             deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
132 |         else:
133 |             B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
134 |             deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
135 |     if is_variable_C and C.dim() == 4:
136 |         C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
137 |     last_state = None
138 |     for i in range(u.shape[2]):
139 |         x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
140 |         if not is_variable_C:
141 |             y = torch.einsum('bdn,dn->bd', x, C)
142 |         else:
143 |             if C.dim() == 3:
144 |                 y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
145 |             else:
146 |                 y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
147 |         if i == u.shape[2] - 1:
148 |             last_state = x
149 |         if y.is_complex():
150 |             y = y.real * 2
151 |         ys.append(y)
152 |     y = torch.stack(ys, dim=2)  # (batch dim L)
153 |     out = y if D is None else y + u * rearrange(D, "d -> d 1")
154 |     if z is not None:
155 |         out = out * F.silu(z)
156 |     out = out.to(dtype=dtype_in)
157 |     return out if not return_last_state else (out, last_state)
158 | 
159 | 
--------------------------------------------------------------------------------
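A quick usage sketch for the wrapper above (illustrative only, not a file from the original repo; it assumes the compiled ```selective_scan_cuda``` extension is installed, and mirrors the call made in ```benchmark.py```):

/example_selective_scan_fn.py (hypothetical):
--------------------------------------------------------------------------------
1 | import torch
2 | from selective_scan_interface import selective_scan_fn
3 | 
4 | B, T, D, K = 2, 64, 128, 16
5 | u = torch.randn(B, T, D, device='cuda', requires_grad=True)
6 | delta = torch.rand(B, T, D, device='cuda', requires_grad=True)
7 | A = (-torch.rand(D, K, device='cuda')).requires_grad_(True)   # negative real decay
8 | B2 = torch.randn(B, T, K, device='cuda', requires_grad=True)
9 | C = torch.randn(B, T, K, device='cuda', requires_grad=True)
10 | D2 = torch.randn(D, device='cuda', requires_grad=True)
11 | 
12 | # inputs are (B, T, ...); the wrapper transposes to the (B, ..., T) kernel layout internally
13 | out = selective_scan_fn(u, delta, A, B2, C, D2)  # (B, T, D)
14 | # the tuple return path: last_state has shape (B, D, K), and its gradient is not propagated
15 | out, last_state = selective_scan_fn(u, delta, A, B2, C, D2, return_last_state=True)
16 | out.sum().backward()
--------------------------------------------------------------------------------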
/triton_sequential_scan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 | from einops import einsum, rearrange, repeat
5 | 
6 | 
7 | assert triton.__version__ != '2.1.0', 'Triton 2.1.0 is missing enable_fp_fusion. Triton 2.2.0 is required for numerical stability of this implementation.'
8 | 
9 | inv_ln2 = 1.44269504
10 | # The kernels below implement the selective-scan recurrence h_t = exp(delta_t * A) * h_{t-1} + (delta_t * u_t) * B_t with output y_t = <h_t, C_t>; D * u is added outside the kernel.
11 | # credit: https://github.com/proger/accelerated-scan/blob/b9edbad65c673f9a1915efe51dc6bbf50fd7f8c4/accelerated_scan/triton.py
12 | 
13 | @torch.jit.script
14 | def reduce(H, C):
15 |     return (H * C.unsqueeze(-2)).sum(-1)
16 | 
17 | @triton.jit
18 | def fwd_recurrence(
19 |     A,
20 |     B,
21 |     C,
22 |     Dt,
23 |     X,
24 |     Y,
25 |     H,
26 |     initial_state,
27 |     T: tl.constexpr,
28 |     D: tl.constexpr,
29 |     K: tl.constexpr,
30 |     BV: tl.constexpr,
31 | ):
32 |     i_bh = tl.program_id(0)
33 |     i_v = tl.program_id(1)
34 | 
35 |     dt_ptr = Dt + i_bh * T * D + i_v * BV + tl.arange(0, BV)
36 |     u_ptr = X + i_bh * T * D + i_v * BV + tl.arange(0, BV)
37 |     o_ptr = Y + i_bh * T * D + i_v * BV + tl.arange(0, BV)  # note: Y is never written here; the output is computed by reduce(H, C) on the host side
38 | 
39 |     h = tl.zeros([BV, K], dtype=tl.float32)
40 | 
41 |     b_ptr = B + i_bh * T * K + tl.arange(0, K)
42 | 
43 |     A = A + ((i_v * BV) + tl.arange(0, BV)
44 |              [:, None]) * K + tl.arange(0, K)[None, :]
45 |     _A = tl.load(A)
46 | 
47 |     H_ptr = H + i_bh * T * D * K + \
48 |         (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :]
49 | 
50 |     h += tl.load(initial_state + i_bh * D * K + (i_v * BV +
51 |                  tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :])
52 | 
53 |     for i in range(T):
54 |         b = tl.load(b_ptr).to(tl.float32)
55 |         dt = tl.load(dt_ptr)
56 |         u = tl.load(u_ptr)
57 |         x_dt = u * dt
58 |         x_dt_b = x_dt[:, None] * b[None, :]
59 |         dt_a = tl.exp(dt[:, None] * _A)
60 |         h = h * dt_a + x_dt_b
61 |         tl.store(H_ptr, h)
62 | 
63 |         b_ptr += K
64 |         dt_ptr += D
65 |         u_ptr += D
66 |         o_ptr += D
67 |         H_ptr += D * K
68 | 
69 | 
70 | @triton.jit
71 | def bwd_recurrence(
72 |     A,
73 |     B,
74 |     C,
75 |     U,
76 |     Dt,
77 |     DO,
78 |     H,
79 |     DA,
80 |     DB,
81 |     DC,
82 |     dDt,
83 |     dU,
84 |     batch,
85 |     initial_state,
86 |     T: tl.constexpr,
87 |     D: tl.constexpr,
88 |     K: tl.constexpr,
89 |     BV: tl.constexpr,
90 | ):
91 |     i_bh = tl.program_id(0)
92 |     i_v = tl.program_id(1)
93 |     NV = tl.cdiv(D, BV)
94 | 
95 |     dt_ptr = Dt + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D
96 |     ddt_ptr = dDt + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D
97 |     u_ptr = U + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D
98 |     du_ptr = dU + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D
99 |     do_ptr = DO + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D
100 | 
101 |     dh = tl.zeros([BV, K], dtype=tl.float32)
102 |     dA = tl.zeros([BV, K], dtype=tl.float32)
103 | 
104 |     b_ptr = B + i_bh * T * K + tl.arange(0, K) + (T - 1) * K
105 |     c_ptr = C + i_bh * T * K + tl.arange(0, K) + (T - 1) * K
106 |     dc_ptr = DC + (i_bh + batch * i_v) * T * K + tl.arange(0, K) + (T - 1) * K
107 |     db_ptr = DB + (i_bh + batch * i_v) * T * K + tl.arange(0, K) + (T - 1) * K
108 | 
109 |     A = A + ((i_v * BV) + tl.arange(0, BV)
110 |              [:, None]) * K + tl.arange(0, K)[None, :]
111 |     _A = tl.load(A)
112 |     H_ptr = H + i_bh * T * D * K + \
113 |         (i_v * BV + tl.arange(0, BV)[:, None]) * K + \
114 |         tl.arange(0, K)[None, :] + (T - 1) * D * K
115 | 
116 |     for i in range(T):
117 |         h = tl.load(H_ptr)
118 |         if i < T - 1:
119 |             next_h = tl.load(H_ptr - D * K)  # hidden state at timestep t-1; "next" in backward iteration order
120 |         else:
121 |             next_h = tl.load(initial_state + i_bh * D * K + (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :])
122 |         b = tl.load(b_ptr).to(tl.float32)
123 |         c = tl.load(c_ptr).to(tl.float32)
124 |         do = tl.load(do_ptr).to(tl.float32)
125 |         u = tl.load(u_ptr).to(tl.float32)
126 |         dt = tl.load(dt_ptr).to(tl.float32)
127 | 
128 |         # gradient wrt output proj
129 |         dc = tl.sum(h * do[:, None], axis=0)
130 |         tl.store(dc_ptr, dc)
131 | 
132 |         # gradient wrt input
133 |         dh += do[:, None] * c[None, :]
134 |         dt_u = dt * u
135 |         db = tl.sum(dh * dt_u[:, None], axis=0)
136 |         tl.store(db_ptr, db)
137 |         ddt_u = tl.sum(dh * b[None, :], axis=1)
138 |         ddt = ddt_u * u
139 |         du = ddt_u * dt
140 |         tl.store(du_ptr, du)
141 | 
142 |         # gradient wrt decay
143 |         dt_a = tl.exp(dt[:, None] * _A)
144 |         dh *= dt_a
145 | 
146 |         d_decay = dh * next_h
147 |         dA += d_decay * dt[:, None]
148 |         ddt += tl.sum(d_decay * _A, axis=1)
149 |         tl.store(ddt_ptr, ddt)
150 | 
151 | 
152 |         # update ptr
153 |         b_ptr -= K
154 |         c_ptr -= K
155 |         dc_ptr -= K
156 |         db_ptr -= K
157 |         dt_ptr -= D
158 |         ddt_ptr -= D
159 |         u_ptr -= D
160 |         du_ptr -= D
161 |         do_ptr -= D
162 |         H_ptr -= D * K
163 | 
164 |     DA_ptr = DA + i_bh * D * K + \
165 |         (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :]
166 |     tl.store(DA_ptr, dA)
167 | 
168 | 
169 | class SelectiveScan(torch.autograd.Function):
170 | 
171 |     @staticmethod
172 |     def forward(ctx, u, delta, A, B, C, initial_state=None):
173 |         b_size, T, d = u.shape
174 |         K = B.shape[-1]
175 | 
176 |         ctx.b_size = b_size
177 |         ctx.T = T
178 |         ctx.d = d
179 |         ctx.K = K
180 |         BV = 64
181 |         num_warps = 4
182 | 
183 |         if b_size <= 16:
184 |             BV = 32
185 |             num_warps = 2
186 | 
187 |         NV = triton.cdiv(d, BV)
188 | 
189 |         o = torch.empty_like(u)  # placeholder; the actual output is computed by reduce(H, C) below
190 |         H = torch.empty(b_size, T, d, K, device=u.device, dtype=torch.float32)
191 | 
192 |         if initial_state is None:
193 |             initial_state = torch.zeros(
194 |                 b_size, d, K, device=u.device, dtype=torch.float32)
195 | 
196 |         fwd_recurrence[(b_size, NV)](A, B, C, delta, u, o, H,
197 |                                      initial_state, T, d, K, BV, num_warps=num_warps, num_stages=1)
198 |         o = reduce(H, C)
199 |         ctx.save_for_backward(A, B, C, delta, H, u)
200 |         ctx.initial_state = initial_state
201 |         return o, H[:, -1]
202 | 
203 |     @staticmethod
204 |     def backward(ctx, grad_output, d_final_state):  # the gradient of the final state is not propagated
205 |         do = grad_output
206 |         A, B, C, delta, H, u = ctx.saved_tensors
207 |         b_size = ctx.b_size
208 |         T = ctx.T
209 |         d = ctx.d
210 |         K = ctx.K
211 | 
212 |         BV = 64
213 |         num_warps = 4
214 | 
215 |         if b_size <= 16:
216 |             BV = 32
217 |             num_warps = 2
218 | 
219 |         NV = triton.cdiv(d, BV)
220 |         dA = A.new_empty(b_size, d, K)
221 |         du = torch.empty_like(u)
222 |         d_delta = torch.empty_like(delta)
223 |         db = B.new_empty(NV, b_size, T, K)
224 |         dc = C.new_empty(NV, b_size, T, K)
225 | 
226 |         bwd_recurrence[(b_size, NV)](A, B, C, u, delta, do, H, dA, db, dc,
227 |                                      d_delta, du, b_size, ctx.initial_state, T, d, K, BV, num_warps=num_warps)
228 |         db = db.sum(0)
229 |         dc = dc.sum(0)
230 | 
231 |         return du, d_delta, dA.sum(0), db, dc, None
232 | 
233 | 
234 | def triton_selective_scan_sequential(u, delta, A, B, C, D, initial_state=None):
235 |     original_dtype = u.dtype
236 |     D = D.float()
237 |     A = A.float()
238 |     if initial_state is not None:
239 |         initial_state = initial_state.detach()
240 |     o, final_state = SelectiveScan.apply(u, delta, A, B, C, initial_state)
241 |     o = o + D * u
242 |     return o.to(original_dtype), final_state
243 | 
244 | 
--------------------------------------------------------------------------------
/triton_parallel_scan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 | from einops import einsum, rearrange, repeat
5 | 
6 | 
7 | assert triton.__version__ != '2.1.0', 'Triton 2.1.0 is missing enable_fp_fusion. Triton 2.2.0 is required for numerical stability of this implementation.'
8 | 9 | inv_ln2 = 1.44269504 10 | 11 | # credit: https://github.com/proger/accelerated-scan/blob/b9edbad65c673f9a1915efe51dc6bbf50fd7f8c4/accelerated_scan/triton.py 12 | 13 | # manual tuple packing by @jackd from https://github.com/openai/triton/issues/2359 14 | @triton.jit 15 | def unpack64(merged): 16 | tl.static_assert(merged.dtype == tl.uint64) 17 | b = (merged & 0xFFFFFFFF).to(tl.uint32).to(tl.float32, bitcast=True) 18 | a = (merged >> 32).to(tl.uint32).to(tl.float32, bitcast=True) 19 | return a, b 20 | 21 | 22 | @triton.jit 23 | def pack64(a, b): 24 | tl.static_assert(a.dtype == tl.float32) 25 | tl.static_assert(b.dtype == tl.float32) 26 | a = a.to(dtype=tl.uint32, bitcast=True).to(tl.uint64) 27 | a = a << 32 28 | b = b.to(dtype=tl.uint32, bitcast=True).to(tl.uint64) 29 | return a | b 30 | 31 | 32 | @triton.jit() 33 | def first_order_op(l, r): 34 | """ 35 | See https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf Section 1.4.1 36 | """ 37 | xl, fl = unpack64(l) 38 | xr, fr = unpack64(r) 39 | x = xl * fr + xr 40 | f = fl * fr 41 | return pack64(x, f) 42 | 43 | 44 | @triton.jit 45 | def forward_scan( 46 | A, 47 | B, 48 | C, 49 | Dt, 50 | X, 51 | Y, 52 | H, 53 | T: tl.constexpr, 54 | D: tl.constexpr, 55 | K: tl.constexpr, 56 | ): 57 | i_bh = tl.program_id(0) 58 | i_v = tl.program_id(1) 59 | 60 | dt_ptr = Dt + i_bh * T * D + i_v * T + tl.arange(0, T) 61 | dt = tl.load(dt_ptr).to(tl.float32) 62 | 63 | x_ptr = X + i_bh * T * D + i_v * T + tl.arange(0, T) 64 | x = tl.load(x_ptr).to(tl.float32) 65 | 66 | x_dt = x * dt 67 | 68 | y = tl.zeros([T,], dtype=tl.float32) 69 | 70 | 71 | for i in range(K): 72 | b_ptr = B + i_bh * T * K + i * T + tl.arange(0, T) 73 | c_ptr = C + i_bh * T * K + i * T + tl.arange(0, T) 74 | H_ptr = H + i_bh * T * D * K + i_v * K * T + i * T + tl.arange(0, T) 75 | b = tl.load(b_ptr).to(tl.float32) 76 | c = tl.load(c_ptr).to(tl.float32) 77 | x_dt_b = x_dt * b 78 | a = tl.load(A + i_v * K + i).to(tl.float32) 79 | dt_a = tl.exp(dt * a) 80 | tuples = pack64(x_dt_b, dt_a) 81 | output_tuples_ = tl.associative_scan(tuples, axis=0, combine_fn=first_order_op) 82 | o, _ = unpack64(output_tuples_) 83 | tl.store(H_ptr, o) 84 | y += (o * c) 85 | 86 | y_ptr = Y + i_bh * T * D + i_v * T + tl.arange(0, T) 87 | tl.store(y_ptr, y) 88 | 89 | 90 | @triton.jit 91 | def backward_scan_du_delta_A( 92 | A, 93 | B, 94 | C, 95 | U, 96 | Dt, 97 | DO, 98 | H, 99 | DA, 100 | DB, 101 | DC, 102 | dDt, 103 | dU, 104 | T: tl.constexpr, 105 | D: tl.constexpr, 106 | K: tl.constexpr, 107 | ): 108 | i_bh = tl.program_id(0) 109 | i_v = tl.program_id(1) 110 | 111 | dt_ptr = Dt + i_bh * T * D + i_v * T + T - tl.arange(0, T) 112 | dt = tl.load(dt_ptr, mask=tl.arange(0, T) > 0, other=0) 113 | 114 | dt_ptr2 = Dt + i_bh * T * D + i_v * T + T - 1 - tl.arange(0, T) 115 | dt2 = tl.load(dt_ptr2) 116 | 117 | dO_ptr = DO + i_bh * T * D + i_v * T + T-1 - tl.arange(0, T) 118 | do = tl.load(dO_ptr) 119 | 120 | d_delta = tl.zeros([T], dtype=tl.float32) 121 | d_u = tl.zeros([T], dtype=tl.float32) 122 | 123 | u_ptr = U + i_bh * T * D + i_v * T + T-1 - tl.arange(0, T) 124 | u = tl.load(u_ptr).to(tl.float32) 125 | 126 | for i in range(K): 127 | H_ptr2 = H + i_bh * T * D * K + i_v * K * T + i * T + (T-1) - tl.arange(0, T) 128 | h2 = tl.load(H_ptr2) 129 | dc = h2 * do 130 | dc_ptr = DC + i_bh * T * K * D + i * D * T + T - 1 - tl.arange(0, T) + i_v * T 131 | tl.store(dc_ptr, dc) 132 | 133 | c_ptr = C + i_bh * T * K + i * T + T - 1 - tl.arange(0, T) 134 | b_ptr = B + i_bh * T * K + i * T + T - 1 - tl.arange(0, T) 135 | # DH_ptr = DH + i_bh 
* T * D * K + i_v * K * T + i * T + T - 1 - tl.arange(0, T) 136 | b = tl.load(b_ptr).to(tl.float32) 137 | c = tl.load(c_ptr).to(tl.float32) 138 | a = tl.load(A + i_v * K + i).to(tl.float32) 139 | dt_a = tl.math.exp(dt * a) 140 | dt_a = tl.where(tl.arange(0, T) > 0, dt_a, 0) 141 | do_c = c * do 142 | tuples = pack64(do_c, dt_a) 143 | output_tuples_ = tl.associative_scan(tuples, axis=0, combine_fn=first_order_op) 144 | dh, _ = unpack64(output_tuples_) 145 | 146 | # gradient wrt input u 147 | d_u += dh * dt2 * b 148 | d_delta += dh * b * u 149 | d_b = dh * u * dt2 150 | 151 | db_ptr = DB + i_bh * T * K * D + i * D * T + T - 1 - tl.arange(0, T) + i_v * T 152 | tl.store(db_ptr, d_b) 153 | 154 | 155 | # gradient wrt decay 156 | H_ptr = H + i_bh * T * D * K + i_v * K * T + i * T + (T-2) - tl.arange(0, T) 157 | h = tl.load(H_ptr, mask=tl.arange(0, T) < T-1, other=0) 158 | d_decay = h * dh * tl.exp(dt2 * a) 159 | d_delta += d_decay * a 160 | d_a = tl.sum(d_decay * dt2) 161 | tl.store(DA + i_bh * K * D + i_v * K + i, d_a) 162 | 163 | tl.store(dU + i_bh * T * D + i_v * T + T-1 - tl.arange(0, T), d_u) 164 | tl.store(dDt + i_bh * T * D + i_v * T + T-1 - tl.arange(0, T), d_delta) 165 | 166 | # @triton.jit 167 | # def grad_compute( 168 | # A, B, C, Dt, U, H, 169 | # dA, dB, dC, dDt, 170 | # dO, dH, dU, 171 | # batch, 172 | # T: tl.constexpr, 173 | # DV: tl.constexpr, 174 | # DK: tl.constexpr, 175 | # BV: tl.constexpr, 176 | # ): 177 | 178 | # i_bh = tl.program_id(0) 179 | # i_v = tl.program_id(1) 180 | 181 | # prev_h = tl.zeros([BV, DK], dtype=tl.float32) 182 | # dA_acc = tl.zeros([BV, DK], dtype=tl.float32) 183 | 184 | # # [BV, DK] 185 | # A = tl.load(A + (tl.arange(0, BV)[:, None] + i_v * BV) * DK + tl.arange(0, DK)[None, :]) 186 | 187 | # H_ptr = H + i_bh * T * DK * DV + (i_v * BV + tl.arange(0, BV)[:, None]) * DK + tl.arange(0, DK)[None, :] 188 | # dH_ptr = dH + i_bh * T * DK * DV + (i_v * BV + tl.arange(0, BV)[:, None]) * DK + tl.arange(0, DK)[None, :] 189 | 190 | # C_ptr = C + i_bh * T * DK + tl.arange(0, DK) 191 | # dC_ptr = dC + (i_bh + i_v * batch) * T * DK + tl.arange(0, DK) 192 | 193 | # B_ptr = B + i_bh * T * DK + tl.arange(0, DK) 194 | # dB_ptr = dB + (i_bh + i_v * batch) * T * DK + tl.arange(0, DK) 195 | 196 | # Dt_ptr = Dt + i_bh * T * DV + i_v * BV + tl.arange(0, BV) 197 | # dDt_ptr = dDt + i_bh * T * DV + i_v * BV + tl.arange(0, BV) 198 | 199 | # u_ptr = U + i_bh * T * DV + i_v * BV + tl.arange(0, BV) 200 | # du_ptr = dU + i_bh * T * DV + i_v * BV + tl.arange(0, BV) 201 | # do_ptr = dO + i_bh * T * DV + i_v * BV + tl.arange(0, BV) 202 | 203 | # for i in range(T): 204 | # h = tl.load(H_ptr) 205 | # dh = tl.load(dH_ptr) 206 | # b = tl.load(B_ptr) 207 | # delta = tl.load(Dt_ptr).to(tl.float32) 208 | # u = tl.load(u_ptr) 209 | # do = tl.load(do_ptr) 210 | 211 | # # gradient wrt output proj 212 | # dc = tl.sum(do[:, None] * h, axis=0) 213 | # tl.store(dC_ptr, dc) 214 | 215 | # # gradient wrt input 216 | # db = tl.sum(dh * u[:, None] * delta[:, None], axis=0) 217 | # du_delta = tl.sum(dh * b[None, :], axis=1) 218 | # d_delta = du_delta * u 219 | # du = du_delta * delta 220 | # tl.store(dB_ptr, db) 221 | # tl.store(du_ptr, du) 222 | 223 | # # gradient wrt decay 224 | # d_decay = prev_h * dh 225 | # gate = tl.exp(delta[:, None] * A) 226 | # d_decay *= gate 227 | # dA_acc += d_decay * delta[:, None] 228 | # d_delta += tl.sum(d_decay * A, axis=1) 229 | # prev_h = h 230 | 231 | # tl.store(dDt_ptr, d_delta.to(dDt.dtype.element_ty)) 232 | 233 | # # update ptrs 234 | # H_ptr += DK * DV 235 | # dH_ptr 
+= DK * DV 236 | 237 | # B_ptr += DK 238 | # dB_ptr += DK 239 | 240 | # dDt_ptr += DV 241 | # Dt_ptr += DV 242 | 243 | # u_ptr += DV 244 | # du_ptr += DV 245 | 246 | # do_ptr += DV 247 | # dC_ptr += DK 248 | 249 | # #fp32 250 | # dA_ptr = dA + i_bh * DV * DK + (tl.arange(0, BV)[:, None] + i_v * BV) * DK + tl.arange(0, DK)[None, :] 251 | # tl.store(dA_ptr, dA_acc) 252 | 253 | class SelectiveScan(torch.autograd.Function): 254 | 255 | @staticmethod 256 | def forward(ctx, u, delta, A, B, C): 257 | b_size, T, d = u.shape 258 | K = B.shape[-1] 259 | 260 | ctx.b_size = b_size 261 | ctx.T = T 262 | ctx.d = d 263 | ctx.K = K 264 | 265 | u = u.transpose(-1, -2).contiguous() 266 | delta = delta.transpose(-1, -2).contiguous() 267 | B = B.transpose(-1, -2).contiguous() 268 | C = C.transpose(-1, -2).contiguous() 269 | o = torch.empty_like(u) 270 | H = torch.empty(b_size, d, K, T, device=u.device, dtype=torch.float32) 271 | forward_scan[(b_size, d)](A, B, C, delta, u, o, H, T, d, K) 272 | ctx.save_for_backward(A, B, C, delta, H, u) 273 | return o.transpose(-1, -2).contiguous() 274 | 275 | 276 | @staticmethod 277 | def backward(ctx, grad_output): 278 | do = grad_output.transpose(-1, -2).contiguous() 279 | A, B, C, delta, H, u = ctx.saved_tensors 280 | b_size = ctx.b_size 281 | T = ctx.T 282 | d = ctx.d 283 | K = ctx.K 284 | 285 | dA = A.new_empty(b_size, d, K) 286 | du = torch.empty_like(u) 287 | d_delta = torch.empty_like(delta) 288 | db = B.new_empty(b_size, K, d, T) 289 | dc = C.new_empty(b_size, K, d, T) 290 | 291 | backward_scan_du_delta_A[(b_size, d)](A, B, C, u, delta, do, H, dA, db, dc, d_delta, du, T, d, K) 292 | db = db.sum(-2) 293 | dc = dc.sum(-2) 294 | 295 | return du.transpose(-1, -2), d_delta.transpose(-1, -2), dA.sum(0), db.transpose(-1, -2), dc.transpose(-1, -2) 296 | 297 | 298 | def triton_selective_scan(u, delta, A, B, C, D): 299 | original_dtype = u.dtype 300 | D = D.float() 301 | A = A.float() 302 | o = SelectiveScan.apply(u, delta, A, B, C) 303 | o += D * u 304 | return o.to(original_dtype) 305 | 306 | 307 | --------------------------------------------------------------------------------
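A sanity check on the combine rule used by ```first_order_op``` above: each scan element (x, f) represents the affine map h -> f * h + x, and two such maps compose as (xl * fr + xr, fl * fr), which is exactly what the kernel computes on bit-packed uint64 pairs. The following pure-PyTorch sketch (illustrative only, not a file from the original repo) verifies that folding this operator over a sequence reproduces the sequential recurrence h_t = f_t * h_{t-1} + x_t:

/check_first_order_op.py (hypothetical):
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def combine(l, r):
4 |     # associative composition of the maps h -> f*h + x, as in first_order_op
5 |     xl, fl = l
6 |     xr, fr = r
7 |     return (xl * fr + xr, fl * fr)
8 | 
9 | T = 8
10 | f = torch.rand(T)   # plays the role of exp(delta * a)
11 | x = torch.randn(T)  # plays the role of delta * b * u
12 | 
13 | # sequential reference: h_t = f_t * h_{t-1} + x_t, with h_{-1} = 0
14 | h, seq = torch.zeros(()), []
15 | for t in range(T):
16 |     h = f[t] * h + x[t]
17 |     seq.append(h)
18 | 
19 | # inclusive prefix combine (what tl.associative_scan computes, but in parallel)
20 | acc, scan = (torch.zeros(()), torch.ones(())), []
21 | for t in range(T):
22 |     acc = combine(acc, (x[t], f[t]))
23 |     scan.append(acc[0])
24 | 
25 | print(torch.allclose(torch.stack(seq), torch.stack(scan)))  # expect True
--------------------------------------------------------------------------------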