├── LICENSE.txt
├── LRU_pytorch
│   ├── LRU.py
│   └── __init__.py
├── README.md
├── setup.cfg
└── setup.py

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Vishnu Jaddipal

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/LRU_pytorch/LRU.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn


class LRU(nn.Module):
    def __init__(self, in_features, out_features, state_features, rmin=0, rmax=1, max_phase=6.283):
        super().__init__()
        self.out_features = out_features
        # Feed-through (skip) matrix D.
        self.D = nn.Parameter(torch.randn([out_features, in_features]) / math.sqrt(in_features))
        # Exponential parameterization of the diagonal state matrix Lambda:
        # eigenvalue moduli are drawn from the ring [rmin, rmax], phases from [0, max_phase].
        u1 = torch.rand(state_features)
        u2 = torch.rand(state_features)
        self.nu_log = nn.Parameter(torch.log(-0.5 * torch.log(u1 * (rmax + rmin) * (rmax - rmin) + rmin ** 2)))
        self.theta_log = nn.Parameter(torch.log(max_phase * u2))
        # Gamma normalization keeps the state magnitude independent of the eigenvalue moduli.
        Lambda_mod = torch.exp(-torch.exp(self.nu_log))
        self.gamma_log = nn.Parameter(torch.log(torch.sqrt(torch.ones_like(Lambda_mod) - torch.square(Lambda_mod))))
        # Complex input and output projections B and C.
        B_re = torch.randn([state_features, in_features]) / math.sqrt(2 * in_features)
        B_im = torch.randn([state_features, in_features]) / math.sqrt(2 * in_features)
        self.B = nn.Parameter(torch.complex(B_re, B_im))
        C_re = torch.randn([out_features, state_features]) / math.sqrt(state_features)
        C_im = torch.randn([out_features, state_features]) / math.sqrt(state_features)
        self.C = nn.Parameter(torch.complex(C_re, C_im))
        # Complex hidden state, initialized to zero.
        self.state = torch.complex(torch.zeros(state_features), torch.zeros(state_features))

    def forward(self, input, state=None):
        self.state = self.state.to(self.B.device) if state is None else state
        # Rebuild Lambda and gamma from their log-parameterizations on every forward pass.
        Lambda_mod = torch.exp(-torch.exp(self.nu_log))
        Lambda_re = Lambda_mod * torch.cos(torch.exp(self.theta_log))
        Lambda_im = Lambda_mod * torch.sin(torch.exp(self.theta_log))
        Lambda = torch.complex(Lambda_re, Lambda_im).to(self.state.device)
        gammas = torch.exp(self.gamma_log).unsqueeze(-1).to(self.state.device)
        output = torch.empty(list(input.shape[:-1]) + [self.out_features], device=self.B.device)
        # Handle input of shape (batch_size, seq_length, in_features).
        if input.dim() == 3:
            for i, batch in enumerate(input):
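                # Each sequence in the batch is scanned sequentially: for every timestep u_t
                # the complex hidden state is updated as
                #     x_t = Lambda * x_{t-1} + (gammas * B) @ u_t
                # and the (real) output is
                #     y_t = (C @ x_t).real + D @ u_t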
                out_seq = torch.empty(input.shape[1], self.out_features, device=self.B.device)
                for j, step in enumerate(batch):
                    self.state = Lambda * self.state + (gammas * self.B) @ step.to(dtype=self.B.dtype)
                    out_step = (self.C @ self.state).real + self.D @ step
                    out_seq[j] = out_step
                # Reset the state between independent sequences in the batch.
                self.state = torch.complex(torch.zeros_like(self.state.real), torch.zeros_like(self.state.real))
                output[i] = out_seq
        # Handle input of shape (seq_length, in_features).
        if input.dim() == 2:
            for i, step in enumerate(input):
                self.state = Lambda * self.state + (gammas * self.B) @ step.to(dtype=self.B.dtype)
                out_step = (self.C @ self.state).real + self.D @ step
                output[i] = out_step
            self.state = torch.complex(torch.zeros_like(self.state.real), torch.zeros_like(self.state.real))
        return output

--------------------------------------------------------------------------------
/LRU_pytorch/__init__.py:
--------------------------------------------------------------------------------
from LRU_pytorch.LRU import LRU

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LRU-pytorch
A PyTorch implementation of the Linear Recurrent Unit (LRU) proposed by DeepMind. LRUs are inspired by deep state-space models, in particular S4 and S5.

# Notes:
+ PyTorch does not currently provide an associative scan, so this implementation runs the recurrence sequentially and will very likely be slower than a JAX implementation.
+ Complex tensors are still in beta in PyTorch and do not fully support .half(), so using torch.float16 is not advised.
+ Certain tensors (Lambda and gamma) are rebuilt from their learned parameters on every forward pass. This is necessary only during training; they could be frozen to speed up inference.

# Installation:
```
$ pip install LRU-pytorch
```
# Usage:
```python
import torch

from LRU_pytorch import LRU

# Create a single Linear Recurrent Unit that takes inputs of shape (batch_size, seq_length, 30) (or (seq_length, 30)),
# keeps an internal state of size 10, and returns outputs of shape (batch_size, seq_length, 50) (or (seq_length, 50)).

lru = LRU(
    in_features=30,
    out_features=50,
    state_features=10
)

preds = lru(torch.randn([2, 50, 30]))  # Get predictions
```
# Parameters:
```in_features```: int. The size of each timestep of the input sequence.

```out_features```: int. The size of each timestep of the output sequence.

```state_features```: int. The size of the internal state variable.
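
# Initial state:
```LRU.forward``` also accepts an optional initial hidden state as a second positional argument: a complex tensor of size ```state_features```. If it is omitted, the layer starts from a zero state. A minimal sketch (shapes follow from the implementation above; the variable names here are illustrative only):
```python
import torch

from LRU_pytorch import LRU

lru = LRU(in_features=30, out_features=50, state_features=10)

# Unbatched input of shape (seq_length, in_features) is also supported.
x = torch.randn([50, 30])

# Optional explicit initial hidden state: a complex tensor of size state_features.
init_state = torch.complex(torch.zeros(10), torch.zeros(10))

preds = lru(x, init_state)  # shape (seq_length, out_features) == (50, 50)
```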

# Paper:
[Resurrecting Recurrent Neural Networks for Long Sequences](https://arxiv.org/abs/2303.06349) (Orvieto et al., 2023)

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
description-file=README.md

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='LRU-pytorch',
    packages=['LRU_pytorch'],
    version='0.1.2',
    license='MIT',
    description='Linear Recurrent Unit (LRU) - Pytorch',
    author='Vishnu Jaddipal',
    author_email='zeus.vj2003@gmail.com',
    url='https://github.com/Gothos/LRU-pytorch',
    download_url='https://github.com/Gothos/LRU-pytorch/archive/refs/tags/v0.1.1-alpha.tar.gz',
    keywords=['Artificial Intelligence', 'Deep Learning', 'Recurrent Neural Networks'],
    install_requires=[
        'torch>=1.13'
    ],
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
    ],
)
--------------------------------------------------------------------------------