├── .gitignore
├── LICENSE
├── README.md
├── example.py
├── network.png
├── setup.py
└── tnn_pytorch
    ├── __init__.py
    ├── glu.py
    ├── gtu.py
    ├── helpers
    │   ├── __init__.py
    │   ├── act_layer.py
    │   └── helpers.py
    ├── norm
    │   ├── __init__.py
    │   ├── offset_scale.py
    │   ├── rms_norm.py
    │   └── scale_norm.py
    ├── rpe.py
    ├── tnn_layer.py
    └── tno.py

/.gitignore:
--------------------------------------------------------------------------------
**/__pycache__/**
*.egg-info/
build/
dist/
test.log
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Doraemonzzz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tnn-Pytorch

This repository contains the Rpe, Tno, Gtu and Tnn layers proposed in [Toeplitz Neural Network for Sequence Modeling](https://openreview.net/forum?id=IxmWsm4xrua). The overall network structure is as follows:

![](./network.png)

## Install

```
$ pip install tnn-pytorch
```

## Introduction

wip.

## Usage

Please refer to `example.py` to understand how to use the modules.
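For a quick start, here is a condensed sketch of what `example.py` does (same toy sizes, random inputs):

```
import torch

from tnn_pytorch import Gtu, TnnLayer, Tno

b, h, n, e, d = 2, 1, 10, 4, 16  # batch, heads, sequence length, embedding, rpe embedding

# Tno operates on (b, h, n, e) tensors
tno = Tno(h, e, d, causal=True)
y = tno(torch.rand(b, h, n, e), dim=-2)

# Gtu and TnnLayer operate on (b, n, e) tensors
gtu = Gtu(embed_dim=e, num_heads=1)
layer = TnnLayer(dim=e, num_heads=1, rpe_embedding=d, glu_dim=e)
x = torch.rand(b, n, e)
print(gtu(x).shape, layer(x).shape)  # both keep the input shape
```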
### Tno

## Recommended configuration

### TnnLayer

We recommend initializing `TnnLayer` with the following configuration; the number of model parameters is then approximately equivalent to that of a standard Transformer layer:

```
dim=your_embedding_dim,
num_heads=1,
rpe_embedding=max(your_embedding_dim // 8, 32),
glu_dim=your_embedding_dim,
# model params
prenorm=True,
norm_type="simplermsnorm",
# gtu params
causal=False,  # True for language models
gtu_act="silu",
expand_ratio=3,
use_decay=True,
gamma=0.9,  # for seqlen < 100; use 0.99 otherwise
# rpe params
rpe_act="relu",
rpe_layers=1,
# glu params
glu_act="silu",
```

### Gtu

If you only want to use `Gtu`, we recommend the following configuration:

```
embed_dim=your_embedding_dim,
num_heads=1,
act_fun="silu",
norm_type="simplermsnorm",
causal=False,  # True for language models
expand_ratio=3,
use_decay=True,
gamma=0.9,  # for seqlen < 100; use 0.99 otherwise
rpe_embedding=max(your_embedding_dim // 8, 32),
rpe_act="relu",
rpe_layers=1,
```

## Todo

- [x] Tnn layer
- [ ] Tnn model
- [x] Recommended configuration
- [ ] Introduction
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import torch

from tnn_pytorch import Gtu, TnnLayer, Tno

# batch size
b = 2
# number of heads
h = 1
# sequence length
n = 10
# embedding size
e = 4
# rpe embedding size
d = 16

print("======Start Test Tno=====")
x = torch.rand(b, h, n, e)
models = [
    Tno(h, e, d, use_decay=True),
    Tno(h, e, d, use_multi_decay=True),
    Tno(h, e, d, causal=True),
]

for dim in [-2]:
    for model in models:
        y1 = model.forward(x, dim=dim)
        # compare the FFT-based computation with the explicit Toeplitz matrix product
        y2 = model.toeplizt_matrix(x, dim=dim)
        print(torch.norm(y1 - y2))
print("======End Test Tno=====")

print("======Start Test Gtu=====")
x = torch.rand(b, n, e)
models = [
    Gtu(
        embed_dim=e,
        num_heads=1,
    )
]

for dim in [-2]:
    for model in models:
        y = model(x)
        print(f"input size is {x.shape}")
        print(f"output size is {y.shape}")
print("======End Test Gtu=====")

print("======Start Test Tnn Layer=====")
x = torch.rand(b, n, e)
models = [
    TnnLayer(
        dim=e,
        num_heads=1,
        rpe_embedding=d,
        glu_dim=e,
    )
]

for dim in [-2]:
    for model in models:
        y = model(x)
        print(f"input size is {x.shape}")
        print(f"output size is {y.shape}")
print("======End Test Tnn Layer=====")
--------------------------------------------------------------------------------
/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doraemonzzz/tnn-pytorch/64b4c354d415ea78ef5b631252de77ceae1a4cc7/network.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name='tnn-pytorch',
    version='0.0.5',
    packages=find_packages(),
    description='Toeplitz Neural Network for Sequence Modeling',
    author='Doraemonzzz',
    author_email='doraemon_zzz@163.com',
    url='https://github.com/Doraemonzzz/tnn-pytorch',
    install_requires=[
        'torch',
        'einops',
    ],
    keywords=[
        'artificial intelligence',
        'sequential model',
    ],
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------
/tnn_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .glu import GLU
from .gtu import Gtu
from .rpe import Rpe
from .tnn_layer import TnnLayer
from .tno import Tno
--------------------------------------------------------------------------------
/tnn_pytorch/glu.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from .helpers import get_activation_fn, print_params


class GLU(nn.Module):
    def __init__(self, d1, d2, act_fun, fina_act="None", dropout=0.0, bias=True):
        super().__init__()
        # get local variables
        params = locals()
        # print params
        print_params(**params)

        self.l1 = nn.Linear(d1, d2, bias=bias)
        self.l2 = nn.Linear(d1, d2, bias=bias)
        self.l3 = nn.Linear(d2, d1, bias=bias)
        self.act_fun = get_activation_fn(act_fun)
        self.p = dropout
        if self.p > 0.0:
            self.dropout = nn.Dropout(p=dropout)
        self.fina_act = get_activation_fn(fina_act)

    def forward(self, x):
        # gate: act(l1(x)), value: l2(x), output: l3(gate * value)
        o1 = self.l1(x)
        weight = self.act_fun(o1)
        if self.p > 0.0:
            weight = self.dropout(weight)
        o2 = self.l2(x)
        output = weight * o2
        output = self.l3(output)
        output = self.fina_act(output)

        return output
--------------------------------------------------------------------------------
/tnn_pytorch/gtu.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from einops import rearrange

from .helpers import get_activation_fn, get_norm_fn, print_params
from .tno import Tno


class Gtu(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        act_fun="silu",
        causal=False,
        expand_ratio=3,
        resi_param=False,
        use_norm=False,
        norm_type="simplermsnorm",
        use_decay=False,
        use_multi_decay=False,
        rpe_layers=3,
        rpe_embedding=512,
        rpe_act="relu",
        normalize=False,
        par_type=1,
        residual=False,
        gamma=0.99,
        act_type="none",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.expand_ratio = expand_ratio
        self.resi_param = resi_param
        self.num_heads = num_heads
        self.normalize = normalize

        if self.resi_param:
            self.d = nn.Parameter(torch.randn(embed_dim))

        d1 = int(self.expand_ratio * embed_dim)
        d1 = (d1 // self.num_heads) * self.num_heads
        self.head_dim = d1 // num_heads
        # linear projection
        self.v_proj = nn.Linear(embed_dim, d1, bias=bias)
        self.u_proj = nn.Linear(embed_dim, d1, bias=bias)
        self.o = nn.Linear(d1, embed_dim, bias=bias)
        self.act = get_activation_fn(act_fun)
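        # Toeplitz neural operator: mixes tokens along the sequence dimension,
        # with coefficients produced by an Rpe MLP of width rpe_embedding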
        # tno
        self.toep = Tno(
            h=num_heads,
            dim=self.head_dim,
            rpe_dim=rpe_embedding,
            causal=causal,
            use_decay=use_decay,
            use_multi_decay=use_multi_decay,
            residual=residual,
            act=rpe_act,
            par_type=par_type,
            gamma=gamma,
            bias=bias,
            act_type=act_type,
            layers=rpe_layers,
            norm_type=norm_type,
        )
        # norm
        self.norm_type = norm_type
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = get_norm_fn(self.norm_type)(d1)

    def forward(self, x):
        # x: b, n, d
        num_heads = self.num_heads

        if self.resi_param:
            # scale the input by the learned per-channel parameter d
            x = x * self.d
        u = self.act(self.u_proj(x))
        v = self.act(self.v_proj(x))
        # reshape
        v = rearrange(v, 'b n (h d) -> b h n d', h=num_heads)
        output = self.toep(v, dim=-2, normalize=self.normalize)
        output = rearrange(output, 'b h n d -> b n (h d)')
        # gate the token-mixed values with u
        output = u * output
        if self.use_norm:
            output = self.norm(output)

        output = self.o(output)

        return output
--------------------------------------------------------------------------------
/tnn_pytorch/helpers/__init__.py:
--------------------------------------------------------------------------------
from .act_layer import ActLayer
from .helpers import (get_activation_fn, get_norm_fn, logging_info,
                      print_config, print_params)
--------------------------------------------------------------------------------
/tnn_pytorch/helpers/act_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from .helpers import get_activation_fn, print_params


class ActLayer(nn.Module):
    def __init__(self, act_fun):
        super().__init__()
        # get local variables
        params = locals()
        # print params
        print_params(**params)

        self.act = get_activation_fn(act_fun)

    def forward(self, x):
        return self.act(x)
--------------------------------------------------------------------------------
/tnn_pytorch/helpers/helpers.py:
--------------------------------------------------------------------------------
import logging
import os
import sys

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

from ..norm import GatedRMSNorm, RMSNorm, ScaleNorm, SimpleRMSNorm

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("print_config")

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    return get_rank() == 0

def logging_info(string):
    if is_main_process():
        logger.info(string)

def print_params(**kwargs):
    if is_main_process():
        logger.info(f"start print config of {kwargs['__class__']}")
        for key in kwargs:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {kwargs[key]}")
        logger.info(f"end print config of {kwargs['__class__']}")

def print_config(config):
    if is_main_process():
        logger.info(f"start print config of {config['__class__']}")
        for key in config:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {config[key]}")
        logger.info(f"end print config of {config['__class__']}")

def get_activation_fn(activation):
    logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return torch.sigmoid
    elif activation == "exp":
        return torch.exp
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":
        def f(x):
            return 1 + F.elu(x)
        return f
    elif activation == "2+elu":
        def f(x):
            return 2 + F.elu(x)
        return f
    elif activation == "silu":
        return F.silu
    else:
        # default: identity
        return lambda x: x

def get_norm_fn(norm_type):
    if norm_type == "rmsnorm":
        return RMSNorm
    elif norm_type == "gatedrmsnorm":
        return GatedRMSNorm
    elif norm_type == "simplermsnorm":
        return SimpleRMSNorm
    elif norm_type == "scalenorm":
        return ScaleNorm
    else:
        return nn.LayerNorm
--------------------------------------------------------------------------------
/tnn_pytorch/norm/__init__.py:
--------------------------------------------------------------------------------
from .rms_norm import SimpleRMSNorm, RMSNorm, GatedRMSNorm
from .scale_norm import ScaleNorm
from .offset_scale import OffsetScale
--------------------------------------------------------------------------------
/tnn_pytorch/norm/offset_scale.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class OffsetScale(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        # elementwise scale and shift over the last dimension
        out = torch.einsum('... d, d -> ... d', x, self.gamma) + self.beta
        return out
--------------------------------------------------------------------------------
/tnn_pytorch/norm/rms_norm.py:
--------------------------------------------------------------------------------
# https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
# https://github.com/bzhangGo/zero/blob/master/modules/rela.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleRMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether to use a bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(SimpleRMSNorm, self).__init__()
        self.eps = eps
        self.d = d

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        d_x = self.d

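        # rms_x = ||x||_2 / sqrt(d): the root mean square over the feature dimension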
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        return x_normed

class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether to use a bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)

    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            # partial RMSNorm: compute the statistic over the first p * d features only
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

class GatedRMSNorm(nn.Module):
    def __init__(self, d, eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization with a learned sigmoid gate
        :param d: model size
        :param eps: epsilon value, default 1e-8
        :param bias: whether to use a bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(GatedRMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)
        self.gate = nn.Parameter(torch.ones(d))
        self.register_parameter("gate", self.gate)

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        d_x = self.d

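        # same RMS statistic as SimpleRMSNorm; the sigmoid gate is applied in the return below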
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        return self.scale * x_normed * torch.sigmoid(self.gate * x)
--------------------------------------------------------------------------------
/tnn_pytorch/norm/scale_norm.py:
--------------------------------------------------------------------------------
# for flash
import torch
from torch import nn

class ScaleNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.d = d
        self.eps = eps
        self.scala = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # divide by the RMS of the features, then apply a single learned scalar
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps) * self.scala
        return x
--------------------------------------------------------------------------------
/tnn_pytorch/rpe.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from .helpers import ActLayer, get_activation_fn, get_norm_fn, print_params


class Rpe(nn.Module):
    def __init__(
        self,
        dim,
        outdim,
        residual,
        act="relu",
        bias=True,
        layers=3,
        norm_type="simplermsnorm",
    ):
        super().__init__()
        # get local variables
        params = locals()
        # print params
        print_params(**params)

        self.residual = residual
        self.outdim = outdim
        self.pos_dim = dim
        self.act = act
        self.pos_proj = nn.Linear(1, self.pos_dim, bias=bias)
        self.layers = nn.ModuleList([])
        for i in range(layers):
            self.layers.append(
                nn.Sequential(
                    get_norm_fn(norm_type)(self.pos_dim),
                    self.get_act(),
                    nn.Linear(self.pos_dim, self.pos_dim, bias=bias),
                )
            )
        self.out = nn.Sequential(
            get_norm_fn(norm_type)(self.pos_dim),
            self.get_act(),
            nn.Linear(self.pos_dim, self.outdim, bias=bias),
        )

    def get_act(self):
        if self.act == "silu":
            return nn.SiLU(inplace=True)
        elif self.act == "relu":
            return nn.ReLU(inplace=True)
        else:
            return ActLayer(self.act)

    def forward(self, biases):
        # biases: (n, 1) relative positions -> (n, outdim) coefficients
        x = self.pos_proj(biases)
        if self.residual:
            for m in self.layers:
                x = m(x) + x
        else:
            for m in self.layers:
                x = m(x)
        x = self.out(x)

        return x
--------------------------------------------------------------------------------
/tnn_pytorch/tnn_layer.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from .glu import GLU
from .gtu import Gtu
from .helpers import get_norm_fn


class TnnLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        rpe_embedding,
        glu_dim,
        # model params
        prenorm=True,
        norm_type="simplermsnorm",
        # gtu params
        causal=False,
        gtu_act="silu",
        expand_ratio=3,
        use_decay=False,
        gamma=0.999,
        # rpe params
        rpe_act="relu",
        rpe_layers=3,
        # glu params
        glu_act="silu",
    ):
        super().__init__()
        self.token_mixer = Gtu(
            # gtu params
            embed_dim=dim,
            num_heads=num_heads,
            act_fun=gtu_act,
            norm_type=norm_type,
            causal=causal,
            expand_ratio=expand_ratio,
            use_decay=use_decay,
            gamma=gamma,
            # rpe params
            rpe_embedding=rpe_embedding,
            rpe_act=rpe_act,
            rpe_layers=rpe_layers,
        )

        self.token_norm = get_norm_fn(norm_type)(dim)
        self.feature_norm = get_norm_fn(norm_type)(dim)

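        # feature (channel) mixer: GLU maps dim -> glu_dim -> dim with a gated activation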
        self.feature_mixer = GLU(
            d1=dim,
            d2=glu_dim,
            act_fun=glu_act,
        )

        if prenorm:
            self.forward = self.forward_prenorm
        else:
            self.forward = self.forward_postnorm

    def forward_postnorm(self, x):
        x = x + self.token_norm(self.token_mixer(x))
        x = x + self.feature_norm(self.feature_mixer(x))

        return x

    def forward_prenorm(self, x):
        x = x + self.token_mixer(self.token_norm(x))
        x = x + self.feature_mixer(self.feature_norm(x))

        return x
--------------------------------------------------------------------------------
/tnn_pytorch/tno.py:
--------------------------------------------------------------------------------
# https://alinush.github.io/2020/03/19/multiplying-a-vector-by-a-toeplitz-matrix.html
# https://stackoverflow.com/questions/69809789/is-there-any-way-to-create-a-tensor-with-a-specific-pattern-in-pytorch
# https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .helpers import ActLayer, get_activation_fn, print_params
from .rpe import Rpe


class Tno(nn.Module):
    def __init__(
        self,
        h,
        dim,
        rpe_dim,
        causal=False,
        use_decay=False,
        use_multi_decay=False,
        residual=False,
        act="relu",
        par_type=1,
        gamma=0.999,
        bias=True,
        act_type="none",
        layers=3,
        norm_type="simplermsnorm",
    ):
        super().__init__()
        # get local variables
        params = locals()
        # print params
        print_params(**params)

        self.h = h
        self.dim = dim
        self.causal = causal
        self.par_type = par_type
        self.zero_value = 0
        self.use_decay = use_decay
        if self.use_decay:
            self.gamma = nn.Parameter(torch.ones(h, 1, dim) * gamma, requires_grad=False)
        self.use_multi_decay = use_multi_decay
        if self.use_multi_decay:
            self.lambda_ = gamma
            self.gamma = nn.Parameter(torch.randn(h, 1, dim))

        self.rpe = Rpe(
            dim=rpe_dim,
            outdim=h * dim,
            residual=residual,
            act=act,
            bias=bias,
            layers=layers,
            norm_type=norm_type,
        )

        if self.causal:
            self.forward = self.forward_causal
        else:
            self.forward = self.forward_non_causal

        self.act_fun = get_activation_fn(act_type)

    def get_pos(self, n):
        if self.par_type == 1:
            index = torch.arange(1, 1 + n).reshape(n, -1) * 1.0
        elif self.par_type == 2:
            index = torch.arange(1, 1 + n).reshape(n, -1) * 1.0 / n
        elif self.par_type == 3:
            index = torch.exp(torch.arange(1, 1 + n).reshape(n, -1) * 1.0 / n)

        return index

    def get_zero(self):
        index = torch.zeros(1).reshape(1, -1) * 1.0
        if self.par_type == 3:
            index = torch.exp(index)

        return index

    def get_neg(self, n):
        if self.causal:
            index = torch.ones(self.h * n * self.dim).reshape(self.h, n, self.dim) * self.zero_value
        else:
            if self.par_type == 1:
                index = -torch.arange(1, 1 + n).flip(0).reshape(n, -1) * 1.0
            elif self.par_type == 2:
                index = -torch.arange(1, 1 + n).flip(0).reshape(n, -1) * 1.0 / n

        return index

    def rpe_transform(self, x):
        # n, 1 -> n, (d * h)
        res = self.rpe(x)
        # n, (d * h) -> h, n, d
        res = rearrange(res, 'n (h d) -> h n d', h=self.h)

        return res

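    # The two forward paths below build the Toeplitz coefficients
    # [a_0, a_1, ..., a_(n-1), a_0, a_(-(n-1)), ..., a_(-1)] (the negative half is
    # zero in the causal case) and multiply by the implied Toeplitz matrix via FFT
    # in compute(), which costs O(n log n) instead of O(n^2).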
    def forward_causal(self, x, dim=-2, normalize=False):
        # x: b, h, n, d
        n = x.shape[dim]
        # a0, a1, ... , a(n-1), a0, a(-(n-1)), ... , a(-1)
        ##### coef
        # 1, d, 1 -> h, 1, d
        zero = self.rpe_transform(self.get_zero().to(x))
        pos = self.rpe_transform(self.get_pos(n - 1).to(x))

        if self.use_decay or self.use_multi_decay:
            coef = torch.arange(1, n).reshape(1, -1, 1).to(x)
            if self.use_decay:
                gamma = self.gamma
            else:
                gamma = torch.sigmoid(self.gamma)
                gamma = self.lambda_ + (1 - self.lambda_) * gamma
            gamma = gamma ** coef
            pos = gamma * pos
        a = torch.cat([zero, pos, zero], dim=1)
        a = self.act_fun(a)

        # x: b, h, n, d
        # a: h, l, d
        output = self.compute(x, a, dim, n)

        if normalize:
            size = list(x.shape[:-1]) + [1]
            ones = torch.ones(size).to(x)
            denorm = self.compute(ones, a, dim, n)
            output = output / denorm

        return output

    def forward_non_causal(self, x, dim=-2, normalize=False):
        # x: b, h, n, d
        n = x.shape[dim]
        # a0, a1, ... , a(n-1), a0, a(-(n-1)), ... , a(-1)
        ##### coef
        # 1, d, 1 -> h, 1, d
        zero = self.rpe_transform(self.get_zero().to(x))
        pos = self.rpe_transform(self.get_pos(n - 1).to(x))
        neg_index = self.get_neg(n - 1).to(x)
        if self.causal:
            neg = neg_index
        else:
            neg = self.rpe_transform(neg_index)

        if self.use_decay or self.use_multi_decay:
            coef = torch.arange(1, n).reshape(1, -1, 1).to(x)
            if self.use_decay:
                gamma = self.gamma
            else:
                gamma = torch.sigmoid(self.gamma)
                gamma = self.lambda_ + (1 - self.lambda_) * gamma
            gamma = gamma ** coef
            pos = gamma * pos
            neg = torch.flip(gamma, dims=[1]) * neg
        a = torch.cat([zero, pos, zero, neg], dim=1)
        a = self.act_fun(a)
        # x: b, h, n, d
        # a: h, l, d
        output = self.compute(x, a, dim, n)

        if normalize:
            size = list(x.shape[:-1]) + [1]
            ones = torch.ones(size).to(x)
            denorm = self.compute(ones, a, dim, n)
            output = output / denorm

        return output

    def compute(self, x, a, dim, n):
        # x: b, h, n, d
        # a: h, n, d
        # circular convolution of length 2n realizes the Toeplitz matrix product
        y = torch.fft.rfft(x, 2 * n, dim=dim)
        v = torch.fft.rfft(a, 2 * n, dim=dim).unsqueeze(0)
        u = v * y
        output = torch.fft.irfft(u, 2 * n, dim=dim)[:, :, :n, :]

        return output

    def toeplizt_matrix(self, x, dim):
        assert dim == -2
        # shape of x: b, h, n, d
        n = x.shape[dim]
        # c: first col, r: first row
        # 1, d, 1 -> h, 1, d
        zero = self.rpe_transform(self.get_zero().to(x))
        pos = self.rpe_transform(self.get_pos(n - 1).to(x))
        neg_index = self.get_neg(n - 1).to(x)
        if self.causal:
            neg = neg_index
        else:
            neg = self.rpe_transform(neg_index)

        if self.use_decay or self.use_multi_decay:
            coef = torch.arange(1, n).reshape(1, -1, 1).to(x)
            if self.use_decay:
                gamma = self.gamma
            else:
                gamma = torch.sigmoid(self.gamma)
                gamma = self.lambda_ + (1 - self.lambda_) * gamma
            gamma = gamma ** coef
            pos = gamma * pos
            neg = torch.flip(gamma, dims=[1]) * neg
        zero = self.act_fun(zero)
        pos = self.act_fun(pos)
        if not self.causal:
            neg = self.act_fun(neg)
        c = torch.cat([zero, pos], dim=-2)
        r = torch.cat([zero, neg.flip(1)], dim=-2)
        vals = torch.cat([r, c[:, 1:].flip(1)], dim=-2)
        n = c.shape[-2]
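        # materialize the full Toeplitz matrix T[i, j] = a_(i-j) by gathering vals
        # with the index difference j - i (used in example.py to check the FFT path)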
        shape = self.h, n, n
        i, j = torch.ones(n, n).nonzero().T
        T = vals[:, j - i].reshape(self.h, n, n, -1)

        res = torch.einsum('h n m d, b h m d -> b h n d', T, x)
        return res
--------------------------------------------------------------------------------