├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── etsformer.png
├── etsformer_pytorch
│   ├── __init__.py
│   └── etsformer_pytorch.py
└── setup.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g.
# github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## ETSformer - Pytorch

Implementation of ETSformer, a state-of-the-art time-series Transformer, in Pytorch

## Install

```bash
$ pip install etsformer-pytorch
```

## Usage

```python
import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,            # in paper they use 512
    embed_kernel_size = 3,      # kernel size for 1d conv for input embedding
    layers = 2,                 # number of encoder and corresponding decoder layers
    heads = 8,                  # number of exponential smoothing attention heads
    K = 4,                      # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2               # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
```

To use ETSformer for classification, wrap it in `ClassificationWrapper`, which performs cross-attention pooling over all the latents and the level output:

```python
import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)
```
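The model can also return its intermediate latents. Calling it without `num_steps_forecast` yields an `Intermediates` namedtuple of `(growth_latents, seasonal_latents, level_output)`, and passing `return_latents = True` returns them alongside the forecast. A quick sketch:

```python
import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

timeseries = torch.randn(1, 1024, 4)

# forecast, and also get back the per-layer latents

pred, latents = model(timeseries, num_steps_forecast = 32, return_latents = True)

latents.growth_latents.shape   # (1, 1024, 2, 512) - (batch, length, layers, model dim)
latents.seasonal_latents.shape # (1, 1024, 2, 512)
latents.level_output.shape     # (1, 1024, 4)      - (batch, length, time features)
```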
## Citation

```bibtex
@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

--------------------------------------------------------------------------------
/etsformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/ETSformer-pytorch/68f9ed9fa361e4a9966fb192275b686eacf00745/etsformer.png

--------------------------------------------------------------------------------
/etsformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
from etsformer_pytorch.etsformer_pytorch import (
    ETSFormer,
    ClassificationWrapper,
    MHESA
)

--------------------------------------------------------------------------------
/etsformer_pytorch/etsformer_pytorch.py:
--------------------------------------------------------------------------------
from math import pi
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn, einsum

from scipy.fftpack import next_fast_len
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# constants

Intermediates = namedtuple('Intermediates', ['growth_latents', 'seasonal_latents', 'level_output'])

# helper functions

def exists(val):
    return val is not None

# fourier helpers

def fourier_extrapolate(signal, start, end):
    device = signal.device
    fhat = torch.fft.fft(signal)
    fhat_len = fhat.shape[-1]
    time = torch.linspace(start, end - 1, end - start, device = device, dtype = torch.complex64)
    freqs = torch.linspace(0, fhat_len - 1, fhat_len, device = device, dtype = torch.complex64)
    res = fhat[..., None, :] * (1.j * 2 * pi * freqs[..., None, :] * time[..., :, None] / fhat_len).exp() / fhat_len
    return res.sum(dim = -1).real

# classes

def InputEmbedding(time_features, model_dim, kernel_size = 3, dropout = 0.):
    return nn.Sequential(
        Rearrange('b n d -> b d n'),
        nn.Conv1d(time_features, model_dim, kernel_size = kernel_size, padding = kernel_size // 2),
        nn.Dropout(dropout),
        Rearrange('b d n -> b n d'),
    )

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.Linear(dim, dim * mult),
        nn.Sigmoid(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim),
        nn.Dropout(dropout)
    )

class FeedForwardBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        **kwargs
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, **kwargs)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.post_norm(x + self.ff(x))

# encoder related classes

## multi-head exponential smoothing attention

def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # Algorithm 3 in paper

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    fast_len = next_fast_len(N + M - 1)

    f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
    f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)

    f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
    out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out
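# the exponential smoothing attention below computes, per head, for 1-indexed position t
# (with x here being the temporal differences of the projected input)
#
#   AES(x)_t = (1 - alpha)^t * v0 + sum over j <= t of alpha * (1 - alpha)^(t - j) * x_j
#
# where v0 is the learned initial state and alpha is a learned (sigmoided) smoothing factor.
# conv1d_fft above evaluates this causal weighted sum in O(n log n), and should match the
# naive O(n^2) einsum in MHESA.naive_Aes (pass naive = True to forward in order to compare)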
class MHESA(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dropout = 0.,
        norm_heads = False
    ):
        super().__init__()
        self.heads = heads
        self.initial_state = nn.Parameter(torch.randn(heads, dim // heads))

        self.dropout = nn.Dropout(dropout)
        self.alpha = nn.Parameter(torch.randn(heads))

        self.norm_heads = nn.Sequential(
            Rearrange('b n (h d) -> b (h d) n', h = heads),
            nn.GroupNorm(heads, dim),
            Rearrange('b (h d) n -> b n (h d)', h = heads)
        ) if norm_heads else nn.Identity()

        self.project_in = nn.Linear(dim, dim)
        self.project_out = nn.Linear(dim, dim)

    def naive_Aes(self, x, weights):
        n, h = x.shape[-2], self.heads

        # in appendix A.1 - Algorithm 2

        arange = torch.arange(n, device = x.device)

        weights = repeat(weights, '... l -> ... t l', t = n)
        indices = repeat(arange, 'l -> h t l', h = h, t = n)

        indices = (indices - rearrange(arange + 1, 't -> 1 t 1')) % n

        weights = weights.gather(-1, indices)
        weights = self.dropout(weights)

        # causal

        weights = weights.tril()

        # multiply

        output = einsum('b h n d, h m n -> b h m d', x, weights)
        return output

    def forward(self, x, naive = False):
        b, n, d, h, device = *x.shape, self.heads, x.device

        # linear project in

        x = self.project_in(x)

        # split out heads

        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # temporal difference

        x = torch.cat((
            repeat(self.initial_state, 'h d -> b h 1 d', b = b),
            x
        ), dim = -2)

        x = x[:, :, 1:] - x[:, :, :-1]

        # prepare exponential alpha

        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # arange == powers

        arange = torch.arange(n, device = device)
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))

        if naive:
            output = self.naive_Aes(x, weights)
        else:
            output = conv1d_fft(x, weights)

        # get initial state contribution

        init_weight = (1 - alpha) ** (arange + 1)
        init_output = rearrange(init_weight, 'h n -> h n 1') * rearrange(self.initial_state, 'h d -> h 1 d')

        output = output + init_output

        # merge heads

        output = rearrange(output, 'b h n d -> b n (h d)')

        # maybe sub-ln from https://arxiv.org/abs/2210.06423 - retnet used groupnorm

        output = self.norm_heads(output)

        return self.project_out(output)
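# note: in each encoder layer (see ETSFormer below), the frequency attention first removes
# the seasonal component from the residual, then MHESA above extracts the growth component
# from what remains - mirroring the decomposition in classical exponential smoothing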
## frequency attention

class FrequencyAttention(nn.Module):
    def __init__(
        self,
        *,
        K = 4,
        dropout = 0.
    ):
        super().__init__()
        self.K = K
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        freqs = torch.fft.rfft(x, dim = 1)

        # get amplitudes

        amp = freqs.abs()
        amp = self.dropout(amp)

        # topk amplitudes - for seasonality, branded as attention

        topk_amp, _ = amp.topk(k = self.K, dim = 1, sorted = True)

        # mask out all freqs with lower amplitudes than the lowest value of the topk above

        topk_freqs = freqs.masked_fill(amp < topk_amp[:, -1:], 0.+0.j)

        # inverse fft

        return torch.fft.irfft(topk_freqs, dim = 1)

## level module

class Level(nn.Module):
    def __init__(self, time_features, model_dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.Tensor([0.]))
        self.to_growth = nn.Linear(model_dim, time_features)
        self.to_seasonal = nn.Linear(model_dim, time_features)

    def forward(self, x, latent_growth, latent_seasonal):
        # following equation in appendix A.2

        n, device = x.shape[1], x.device

        alpha = self.alpha.sigmoid()

        arange = torch.arange(n, device = device)
        powers = torch.flip(arange, dims = (0,))

        # Aes for raw time series signal with seasonal terms (from frequency attention) subtracted out

        seasonal = self.to_seasonal(latent_seasonal)
        Aes_weights = alpha * (1 - alpha) ** powers
        seasonal_normalized_term = conv1d_fft(x - seasonal, Aes_weights)

        # auxiliary term

        growth = self.to_growth(latent_growth)
        growth_smoothing_weights = (1 - alpha) ** powers
        growth_term = conv1d_fft(growth, growth_smoothing_weights)

        return seasonal_normalized_term + growth_term

# decoder classes

class LevelStack(nn.Module):
    def forward(self, x, num_steps_forecast):
        return repeat(x[:, -1], 'b d -> b n d', n = num_steps_forecast)

class GrowthDampening(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.dampen_factor = nn.Parameter(torch.randn(heads))

    def forward(self, growth, *, num_steps_forecast):
        device, h = growth.device, self.heads

        dampen_factor = self.dampen_factor.sigmoid()

        # like level stack, it takes the last growth for forecasting

        last_growth = growth[:, -1]
        last_growth = rearrange(last_growth, 'b l (h d) -> b l 1 h d', h = h)

        # prepare dampening factors per head and the powers

        dampen_factor = rearrange(dampen_factor, 'h -> 1 1 1 h 1')
        powers = (torch.arange(num_steps_forecast, device = device) + 1)
        powers = rearrange(powers, 'n -> 1 1 n 1 1')

        # following Eq(2) in the paper

        dampened_growth = last_growth * (dampen_factor ** powers).cumsum(dim = 2)
        return rearrange(dampened_growth, 'b l n h d -> b l n (h d)')
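# the decoder in ETSFormer.forward below composes the forecast as
#
#   forecast = stacked last level
#            + project( sum over layers of dampened growth
#                     + sum over layers of extrapolated seasonality )
#
# where growth latents are dampened per head by GrowthDampening above, and seasonal
# latents are extended beyond the lookback window with fourier_extrapolate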
# main class

class ETSFormer(nn.Module):
    def __init__(
        self,
        *,
        model_dim,
        time_features = 1,
        embed_kernel_size = 3,
        layers = 2,
        heads = 8,
        K = 4,
        dropout = 0.
    ):
        super().__init__()
        assert (model_dim % heads) == 0, 'model dimension must be divisible by number of heads'
        self.model_dim = model_dim
        self.time_features = time_features

        self.embed = InputEmbedding(time_features, model_dim, kernel_size = embed_kernel_size, dropout = dropout)

        self.encoder_layers = nn.ModuleList([])

        for ind in range(layers):
            is_last_layer = ind == (layers - 1)

            self.encoder_layers.append(nn.ModuleList([
                FrequencyAttention(K = K, dropout = dropout),
                MHESA(dim = model_dim, heads = heads, dropout = dropout),
                FeedForwardBlock(dim = model_dim) if not is_last_layer else None,
                Level(time_features = time_features, model_dim = model_dim)
            ]))

        self.growth_dampening_module = GrowthDampening(dim = model_dim, heads = heads)

        self.latents_to_time_features = nn.Linear(model_dim, time_features)
        self.level_stack = LevelStack()

    def forward(
        self,
        x,
        *,
        num_steps_forecast = 0,
        return_latents = False
    ):
        one_time_feature = x.ndim == 2

        if one_time_feature:
            x = rearrange(x, 'b n -> b n 1')

        z = self.embed(x)

        latent_growths = []
        latent_seasonals = []

        for freq_attn, mhes_attn, ff_block, level in self.encoder_layers:
            latent_seasonal = freq_attn(z)
            z = z - latent_seasonal

            latent_growth = mhes_attn(z)
            z = z - latent_growth

            if exists(ff_block):
                z = ff_block(z)

            x = level(x, latent_growth, latent_seasonal)

            latent_growths.append(latent_growth)
            latent_seasonals.append(latent_seasonal)

        latent_growths = torch.stack(latent_growths, dim = -2)
        latent_seasonals = torch.stack(latent_seasonals, dim = -2)

        latents = Intermediates(latent_growths, latent_seasonals, x)

        if num_steps_forecast == 0:
            return latents

        latent_seasonals = rearrange(latent_seasonals, 'b n l d -> b l d n')
        extrapolated_seasonals = fourier_extrapolate(latent_seasonals, x.shape[1], x.shape[1] + num_steps_forecast)
        extrapolated_seasonals = rearrange(extrapolated_seasonals, 'b l d n -> b l n d')

        dampened_growths = self.growth_dampening_module(latent_growths, num_steps_forecast = num_steps_forecast)
        level = self.level_stack(x, num_steps_forecast = num_steps_forecast)

        summed_latents = dampened_growths.sum(dim = 1) + extrapolated_seasonals.sum(dim = 1)
        forecasted = level + self.latents_to_time_features(summed_latents)

        if one_time_feature:
            forecasted = rearrange(forecasted, 'b n 1 -> b n')

        if return_latents:
            return forecasted, latents

        return forecasted
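# the classification wrapper below pools the encoder outputs with cross attention:
# a set of learned queries (one per head) attends over keys / values derived from the
# growth latents, seasonal latents, and level output (each through its own 1d conv),
# and the pooled vector is projected to class logits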
# classification wrapper

class MultiheadLayerNorm(nn.Module):
    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(heads, 1, dim))
        self.b = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        std = torch.var(x, dim = -1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class ClassificationWrapper(nn.Module):
    def __init__(
        self,
        *,
        etsformer,
        num_classes = 10,
        heads = 16,
        dim_head = 32,
        level_kernel_size = 3,
        growth_kernel_size = 3,
        seasonal_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()
        assert isinstance(etsformer, ETSFormer)
        self.etsformer = etsformer
        model_dim = etsformer.model_dim
        time_features = etsformer.time_features

        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)

        self.queries = nn.Parameter(torch.randn(heads, dim_head))

        self.growth_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            nn.Conv1d(model_dim, inner_dim * 2, growth_kernel_size, bias = False, padding = growth_kernel_size // 2),
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        self.seasonal_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            nn.Conv1d(model_dim, inner_dim * 2, seasonal_kernel_size, bias = False, padding = seasonal_kernel_size // 2),
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        self.level_to_kv = nn.Sequential(
            Rearrange('b n t -> b t n'),
            nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
            Rearrange('b (kv h d) n -> b (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        self.to_out = nn.Linear(inner_dim, model_dim)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(model_dim),
            nn.Linear(model_dim, num_classes)
        )

    def forward(self, timeseries):
        latent_growths, latent_seasonals, level_output = self.etsformer(timeseries)

        latent_growths = latent_growths.mean(dim = -2)
        latent_seasonals = latent_seasonals.mean(dim = -2)

        # queries, key, values

        q = self.queries * self.scale

        kvs = torch.cat((
            self.growth_to_kv(latent_growths),
            self.seasonal_to_kv(latent_seasonals),
            self.level_to_kv(level_output)
        ), dim = -2)

        k, v = kvs.chunk(2, dim = 1)

        # cross attention pooling

        sim = einsum('h d, b h j d -> b h j', q, k)
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h j, b h j d -> b h d', attn, v)
        out = rearrange(out, 'b ... -> b (...)')

        out = self.to_out(out)

        # project to logits

        return self.to_logits(out)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'ETSformer-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.1',
  license='MIT',
  description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/ETSformer-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'time-series',
    'forecasting'
  ],
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
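For completeness, a minimal training-step sketch for forecasting. This is illustrative only; the optimizer choice, learning rate, batch shapes, and synthetic data are assumptions, not part of the repository:

```python
import torch
import torch.nn.functional as F
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

# hypothetical data: given 1024 observed steps, learn to forecast the next 32

context = torch.randn(16, 1024, 4)
target = torch.randn(16, 32, 4)

pred = model(context, num_steps_forecast = 32) # (16, 32, 4)

loss = F.mse_loss(pred, target)
loss.backward()

optim.step()
optim.zero_grad()
```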