├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── etsformer.png
├── etsformer_pytorch
├── __init__.py
└── etsformer_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # checkout@v2 / setup-python@v2 run on the retired Node 12 runtime;
    # use the current major versions instead.
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## ETSformer - Pytorch
4 |
5 | Implementation of ETSformer, a state-of-the-art time-series Transformer, in PyTorch
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install etsformer-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from etsformer_pytorch import ETSFormer
18 |
19 | model = ETSFormer(
20 | time_features = 4,
21 | model_dim = 512, # in paper they use 512
22 | embed_kernel_size = 3, # kernel size for 1d conv for input embedding
23 | layers = 2, # number of encoder and corresponding decoder layers
24 | heads = 8, # number of exponential smoothing attention heads
25 | K = 4, # num frequencies with highest amplitude to keep (attend to)
26 | dropout = 0.2 # dropout (in paper they did 0.2)
27 | )
28 |
29 | timeseries = torch.randn(1, 1024, 4)
30 |
31 | pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
32 | ```
33 |
34 | To use ETSFormer for classification, wrap it in `ClassificationWrapper`, which pools over all latents and the level output with cross attention.
35 |
36 | ```python
37 | import torch
38 | from etsformer_pytorch import ETSFormer, ClassificationWrapper
39 |
40 | etsformer = ETSFormer(
41 | time_features = 1,
42 | model_dim = 512,
43 | embed_kernel_size = 3,
44 | layers = 2,
45 | heads = 8,
46 | K = 4,
47 | dropout = 0.2
48 | )
49 |
50 | adapter = ClassificationWrapper(
51 | etsformer = etsformer,
52 | dim_head = 32,
53 | heads = 16,
54 | dropout = 0.2,
55 | level_kernel_size = 5,
56 | num_classes = 10
57 | )
58 |
59 | timeseries = torch.randn(1, 1024)
60 |
61 | logits = adapter(timeseries) # (1, 10)
62 | ```
63 |
64 | ## Citation
65 |
66 | ```bibtex
67 | @misc{woo2022etsformer,
68 | title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
69 | author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
70 | year = {2022},
71 | eprint = {2202.01381},
72 | archivePrefix = {arXiv},
73 | primaryClass = {cs.LG}
74 | }
75 | ```
76 |
--------------------------------------------------------------------------------
/etsformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/ETSformer-pytorch/68f9ed9fa361e4a9966fb192275b686eacf00745/etsformer.png
--------------------------------------------------------------------------------
/etsformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from etsformer_pytorch.etsformer_pytorch import (
2 | ETSFormer,
3 | ClassificationWrapper,
4 | MHESA
5 | )
6 |
--------------------------------------------------------------------------------
/etsformer_pytorch/etsformer_pytorch.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | from collections import namedtuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 |
8 | from scipy.fftpack import next_fast_len
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 |
12 | # constants
13 |
# Encoder outputs returned by ETSFormer, in the order
# (growth_latents, seasonal_latents, level_output).
Intermediates = namedtuple(
    'Intermediates',
    'growth_latents seasonal_latents level_output',
)
15 |
16 | # helper functions
17 |
def exists(val):
    """Return True for any value other than ``None`` (falsy values still count)."""
    return not (val is None)
20 |
21 | # fourier helpers
22 |
def fourier_extrapolate(signal, start, end):
    """Evaluate the Fourier series of ``signal`` at time indices ``start .. end - 1``.

    The DFT is taken over the last dimension (length ``N``). The inverse DFT is
    then evaluated at arbitrary time indices, which extends the signal
    periodically with period ``N``. Evaluating at ``0 .. N - 1`` reconstructs
    the input.

    Args:
        signal: real tensor whose last dimension is the time axis.
        start: first time index to evaluate (inclusive).
        end: time index to stop at (exclusive).

    Returns:
        Real tensor of shape ``(*signal.shape[:-1], end - start)`` with the same
        floating dtype as the spectrum's real part (float64 stays float64).
    """
    device = signal.device
    fhat = torch.fft.fft(signal)
    fhat_len = fhat.shape[-1]

    # Follow the spectrum's precision (complex64 -> float32, complex128 -> float64)
    # rather than forcing complex64, which silently truncated float64 inputs.
    real_dtype = fhat.real.dtype

    time = torch.arange(start, end, device = device)
    freqs = torch.arange(fhat_len, device = device)

    # exp(2*pi*i * f * t / N) only depends on (f * t) mod N. Reducing the product
    # exactly before scaling keeps the phase argument small, so far-future time
    # steps do not lose accuracy to large floating point phases.
    phase_index = (time[:, None] * freqs[None, :]) % fhat_len
    phase = phase_index.to(real_dtype) * (2 * pi / fhat_len)

    twiddle = torch.exp(1.j * phase)  # (end - start, N)
    res = fhat[..., None, :] * twiddle / fhat_len
    return res.sum(dim = -1).real
31 |
32 | # classes
33 |
def InputEmbedding(time_features, model_dim, kernel_size = 3, dropout = 0.):
    """Build the input embedding: a 1d convolution over the time axis.

    Maps ``(batch, seq, time_features)`` to ``(batch, seq', model_dim)``. The
    padding is ``kernel_size // 2``, so the sequence length is preserved for
    odd kernel sizes.
    """
    conv = nn.Conv1d(
        time_features,
        model_dim,
        kernel_size = kernel_size,
        padding = kernel_size // 2,
    )
    layers = [
        Rearrange('b n d -> b d n'),  # features first for Conv1d
        conv,
        nn.Dropout(dropout),
        Rearrange('b d n -> b n d'),  # back to features last
    ]
    return nn.Sequential(*layers)
41 |
def FeedForward(dim, mult = 4, dropout = 0.):
    """Position-wise feedforward: Linear -> Sigmoid -> Dropout -> Linear -> Dropout.

    The hidden width is ``dim * mult``; the output width equals ``dim``.
    """
    hidden_dim = dim * mult
    layers = [
        nn.Linear(dim, hidden_dim),
        nn.Sigmoid(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
50 |
class FeedForwardBlock(nn.Module):
    """LayerNorm, then a residual feedforward, then another LayerNorm.

    The residual branch starts from the *normalized* input:
    ``post_norm(norm(x) + ff(norm(x)))``.
    """

    def __init__(
        self,
        *,
        dim,
        **kwargs
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # any extra keyword arguments are forwarded to FeedForward
        self.ff = FeedForward(dim, **kwargs)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        with_residual = normed + self.ff(normed)
        return self.post_norm(with_residual)
66 |
67 | # encoder related classes
68 |
69 | ## multi-head exponential smoothing attention
70 |
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    """Causal 1d convolution of ``x`` with ``weights``, computed through the FFT.

    This is Algorithm 3 of the ETSformer paper. Let ``N = x.shape[dim]`` and
    assume ``weights`` has the same length. Output step ``t`` is then
    ``sum_{j <= t} weights[..., N - 1 - (t - j)] * x[..., j, ...]``: the last
    weight applies to the current step, and earlier weights reach further back.

    A trailing singleton axis is appended to the weight spectrum. With the
    default ``weight_dim=-1``, its frequency axis then broadcasts against
    ``x``'s ``dim=-2``.
    """
    seq_len = x.shape[dim]
    kernel_len = weights.shape[weight_dim]

    # pad to >= N + M - 1 so the circular FFT product does not wrap around
    n_fft = next_fast_len(seq_len + kernel_len - 1)

    x_spec = torch.fft.rfft(x, n = n_fft, dim = dim)
    w_spec = torch.fft.rfft(weights, n = n_fft, dim = weight_dim)

    # multiplying by the conjugate spectrum yields a cross-correlation
    product = x_spec * w_spec.conj().unsqueeze(-1)
    corr = torch.fft.irfft(product, n_fft, dim = dim).roll(-1, dims = (dim,))

    # after the roll, the causal outputs sit in the final N positions
    keep = torch.arange(n_fft - seq_len, n_fft, dtype = torch.long, device = x.device)
    return corr.index_select(dim, keep)
89 |
class MHESA(nn.Module):
    """Multi-head exponential smoothing attention (MH-ESA).

    Each head smooths the temporal first differences of its slice of the
    projected input. The smoothing weights are ``alpha * (1 - alpha) ** k``,
    where ``alpha = sigmoid(self.alpha)`` is learned per head. A learned
    per-head initial state is prepended before differencing. Its decayed
    contribution ``(1 - alpha) ** (t + 1) * initial_state`` is added back
    afterwards.
    """

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dropout = 0.,
        norm_heads = False
    ):
        """
        Args:
            dim: model dimension, split into ``heads`` chunks of ``dim // heads``.
            heads: number of smoothing heads, each with its own alpha.
            dropout: dropout probability on the explicit smoothing matrix. Only
                ``naive_Aes`` applies it; the FFT path in ``forward`` does not.
            norm_heads: if True, apply GroupNorm with ``heads`` groups to the
                merged head outputs before the output projection.
        """
        super().__init__()
        self.heads = heads
        # learned state prepended to the sequence before temporal differencing
        self.initial_state = nn.Parameter(torch.randn(heads, dim // heads))

        self.dropout = nn.Dropout(dropout)
        # per-head smoothing logit; forward maps it into (0, 1) with a sigmoid
        self.alpha = nn.Parameter(torch.randn(heads))

        # move merged head features to dim 1 for GroupNorm, then move them back
        self.norm_heads = nn.Sequential(
            Rearrange('b n (h d) -> b (h d) n', h = heads),
            nn.GroupNorm(heads, dim),
            Rearrange('b (h d) n -> b n (h d)', h = heads)
        ) if norm_heads else nn.Identity()

        self.project_in = nn.Linear(dim, dim)
        self.project_out = nn.Linear(dim, dim)

    def naive_Aes(self, x, weights):
        """Reference O(n^2) smoothing that builds the causal weight matrix explicitly.

        Appendix A.1, Algorithm 2 of the paper. For each head, entry (t, l) of
        the lower-triangular matrix is ``weights[h, n - 1 - (t - l)]``.

        Args:
            x: tensor indexed as (b, h, n, d) by the einsum below.
            weights: per-head weights of shape (h, n); the last entry multiplies
                the current time step.

        Returns:
            Tensor of shape (b, h, n, d).
        """
        n, h = x.shape[-2], self.heads

        # in appendix A.1 - Algorithm 2

        arange = torch.arange(n, device = x.device)

        # copy each head's weight vector to every output row t
        weights = repeat(weights, '... l -> ... t l', t = n)
        indices = repeat(arange, 'l -> h t l', h = h, t = n)

        # row t, column l reads weights[(l - t - 1) mod n]; for l <= t that is weights[n - 1 - (t - l)]
        indices = (indices - rearrange(arange + 1, 't -> 1 t 1')) % n

        weights = weights.gather(-1, indices)
        weights = self.dropout(weights)

        # causal

        # entries above the diagonal hold wrapped-around (future) weights; zero them
        weights = weights.tril()

        # multiply

        output = einsum('b h n d, h m n -> b h m d', x, weights)
        return output

    def forward(self, x, naive = False):
        """Apply MH-ESA to ``x``.

        Args:
            x: tensor unpacked as (b, n, dim).
            naive: if True, use ``naive_Aes`` (explicit matrix, with dropout).
                Otherwise use the FFT-based ``conv1d_fft`` (no dropout).

        Returns:
            Tensor of shape (b, n, dim).
        """
        b, n, d, h, device = *x.shape, self.heads, x.device

        # linear project in

        x = self.project_in(x)

        # split out heads

        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # temporal difference

        # prepend the initial state so the first difference is x_0 - initial_state
        x = torch.cat((
            repeat(self.initial_state, 'h d -> b h 1 d', b = b),
            x
        ), dim = -2)

        x = x[:, :, 1:] - x[:, :, :-1]

        # prepare exponential alpha

        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # arange == powers

        arange = torch.arange(n, device = device)
        # weights[h, j] = alpha * (1 - alpha) ** (n - 1 - j); the last entry weights the current step
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))

        if naive:
            output = self.naive_Aes(x, weights)
        else:
            output = conv1d_fft(x, weights)

        # get initial state contribution

        # (1 - alpha) ** (t + 1) * initial_state, shape (h, n, d), broadcast over the batch
        init_weight = (1 - alpha) ** (arange + 1)
        init_output = rearrange(init_weight, 'h n -> h n 1') * rearrange(self.initial_state, 'h d -> h 1 d')

        output = output + init_output

        # merge heads

        output = rearrange(output, 'b h n d -> b n (h d)')

        # maybe sub-ln from https://arxiv.org/abs/2210.06423 - retnet used groupnorm

        output = self.norm_heads(output)

        return self.project_out(output)
190 |
191 | ## frequency attention
192 |
class FrequencyAttention(nn.Module):
    """Frequency attention: keep only the K strongest Fourier components over time.

    Takes the real FFT of the input along dimension 1. For each batch element
    and feature, it keeps the frequencies whose amplitude is at least the K-th
    largest, zeroes the rest and transforms back. ETSFormer uses the result as
    the seasonal latent.
    """

    def __init__(
        self,
        *,
        K = 4,
        dropout = 0.
    ):
        """
        Args:
            K: number of highest-amplitude frequencies to keep. It is clamped to
                the number of available rfft bins, so short sequences do not
                make ``topk`` fail.
            dropout: dropout applied to the amplitudes used for top-k selection.
        """
        super().__init__()
        self.K = K
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: tensor whose dimension 1 is the time axis, e.g. (batch, seq, dim).

        Returns:
            Tensor with the same shape as ``x`` that contains only the
            selected frequencies.
        """
        seq_len = x.shape[1]
        freqs = torch.fft.rfft(x, dim = 1)

        # get amplitudes

        amp = freqs.abs()
        amp = self.dropout(amp)

        # topk amplitudes - for seasonality, branded as attention
        # (topk raises if k exceeds the number of bins, seq_len // 2 + 1)

        k = min(self.K, freqs.shape[1])
        topk_amp, _ = amp.topk(k = k, dim = 1, sorted = True)

        # mask out all freqs with lower amplitudes than the lowest value of the topk above

        topk_freqs = freqs.masked_fill(amp < topk_amp[:, -1:], 0.+0.j)

        # inverse fft - pass the original length explicitly. Without it irfft
        # returns 2 * (bins - 1) samples, which is one short for odd lengths.

        return torch.fft.irfft(topk_freqs, n = seq_len, dim = 1)
223 |
224 | ## level module
225 |
class Level(nn.Module):
    """Level module, following the equation in appendix A.2 of the paper.

    Exponentially smooths the input after subtracting the seasonal component
    projected from the seasonal latent. It then adds a smoothed growth term
    projected from the growth latent. One learned scalar, passed through a
    sigmoid, controls both smoothings.
    """

    def __init__(self, time_features, model_dim):
        super().__init__()
        # smoothing logit, initialised to 0 so sigmoid(alpha) starts at 0.5
        self.alpha = nn.Parameter(torch.Tensor([0.]))
        self.to_growth = nn.Linear(model_dim, time_features)
        self.to_seasonal = nn.Linear(model_dim, time_features)

    def forward(self, x, latent_growth, latent_seasonal):
        """
        Args:
            x: series expected as (batch, seq, time_features); time is read
                from dim 1 and convolved along dim -2.
            latent_growth: growth latent with ``model_dim`` features, aligned
                with ``x`` in time.
            latent_seasonal: seasonal latent with ``model_dim`` features,
                aligned with ``x`` in time.

        Returns:
            The level estimate, shaped like ``x``.
        """
        seq_len = x.shape[1]
        smoothing = self.alpha.sigmoid()
        decay = 1 - smoothing

        # exponent n - 1 - j at position j: the newest step gets decay ** 0
        exponents = torch.arange(seq_len, device = x.device).flip(0)

        # smoothed (input - seasonal) with weights alpha * (1 - alpha) ** k
        deseasonalized = x - self.to_seasonal(latent_seasonal)
        level_part = conv1d_fft(deseasonalized, smoothing * decay ** exponents)

        # auxiliary growth term smoothed with weights (1 - alpha) ** k
        growth_part = conv1d_fft(self.to_growth(latent_growth), decay ** exponents)

        return level_part + growth_part
256 |
257 | # decoder classes
258 |
class LevelStack(nn.Module):
    """Forecast the level by holding its last observed value constant."""

    def forward(self, x, num_steps_forecast):
        """Repeat ``x[:, -1]`` ``num_steps_forecast`` times along a new time axis.

        Args:
            x: level series indexed as (b, n, d).
            num_steps_forecast: length of the forecast horizon.

        Returns:
            Tensor of shape (b, num_steps_forecast, d).
        """
        final_step = x[:, -1]
        return repeat(final_step, 'b d -> b n d', n = num_steps_forecast)
262 |
class GrowthDampening(nn.Module):
    """Damped-trend forecast of the growth latents (Eq. 2 in the paper).

    Each head has a damping factor ``phi = sigmoid(dampen_factor)``. At
    forecast step k (counting from 1), the last observed growth is scaled by
    ``phi + phi**2 + ... + phi**k``.
    """

    def __init__(
        self,
        dim,
        heads = 8
    ):
        super().__init__()
        # NOTE: ``dim`` is accepted but not used by this module
        self.heads = heads
        self.dampen_factor = nn.Parameter(torch.randn(heads))

    def forward(self, growth, *, num_steps_forecast):
        """
        Args:
            growth: growth latents indexed as (b, n, l, h * d), with time on
                dim 1 and layers on dim 2.
            num_steps_forecast: number of future steps to produce.

        Returns:
            Tensor of shape (b, l, num_steps_forecast, h * d).
        """
        heads = self.heads
        phi = rearrange(self.dampen_factor.sigmoid(), 'h -> 1 1 1 h 1')

        # like level stack, only the most recent growth is extrapolated
        latest = rearrange(growth[:, -1], 'b l (h d) -> b l 1 h d', h = heads)

        # exponents 1 .. num_steps_forecast laid out along the forecast axis
        steps = torch.arange(1, num_steps_forecast + 1, device = growth.device)
        steps = rearrange(steps, 'n -> 1 1 n 1 1')

        # running sum of phi ** k gives the damped multiplier for every step
        multiplier = (phi ** steps).cumsum(dim = 2)
        return rearrange(latest * multiplier, 'b l n h d -> b l n (h d)')
293 |
294 | # main class
295 |
class ETSFormer(nn.Module):
    """ETSformer: exponential smoothing Transformer for time-series forecasting.

    Encoder: each layer
    - peels a seasonal latent (frequency attention) off the embedded series;
    - peels a growth latent (multi-head exponential smoothing attention) off
      what remains;
    - lets a ``Level`` module update the level of the raw series.

    Forecast:
    - extrapolate the seasonal latents with a Fourier series;
    - damp the last growth latents;
    - repeat the last level;
    - project the summed latents back to ``time_features``.
    """

    def __init__(
        self,
        *,
        model_dim,
        time_features = 1,
        embed_kernel_size = 3,
        layers = 2,
        heads = 8,
        K = 4,
        dropout = 0.
    ):
        """
        Args:
            model_dim: latent dimension; must be divisible by ``heads``.
            time_features: number of features per time step in the input series.
            embed_kernel_size: kernel size of the input embedding convolution.
            layers: number of encoder layers.
            heads: number of smoothing heads in MHESA and in growth damping.
            K: number of top-amplitude frequencies kept by frequency attention.
            dropout: dropout passed to the embedding, frequency attention and MHESA.

        Raises:
            AssertionError: if ``model_dim`` is not divisible by ``heads``.
        """
        super().__init__()
        assert (model_dim % heads) == 0, 'model dimension must be divisible by number of heads'
        self.model_dim = model_dim
        self.time_features = time_features

        self.embed = InputEmbedding(time_features, model_dim, kernel_size = embed_kernel_size, dropout = dropout)

        self.encoder_layers = nn.ModuleList([])

        for ind in range(layers):
            is_last_layer = ind == (layers - 1)

            # per layer: [frequency attention, MHESA, feedforward (None on the last layer), level]
            self.encoder_layers.append(nn.ModuleList([
                FrequencyAttention(K = K, dropout = dropout),
                MHESA(dim = model_dim, heads = heads, dropout = dropout),
                FeedForwardBlock(dim = model_dim) if not is_last_layer else None,
                Level(time_features = time_features, model_dim = model_dim)
            ]))

        self.growth_dampening_module = GrowthDampening(dim = model_dim, heads = heads)

        self.latents_to_time_features = nn.Linear(model_dim, time_features)
        self.level_stack = LevelStack()

    def forward(
        self,
        x,
        *,
        num_steps_forecast = 0,
        return_latents = False
    ):
        """Encode ``x`` and optionally forecast future steps.

        Args:
            x: series of shape (batch, seq, time_features), or (batch, seq)
                for a single time feature.
            num_steps_forecast: number of future steps to forecast. If 0, only
                the encoder intermediates are returned.
            return_latents: when forecasting, also return the encoder
                intermediates.

        Returns:
            If ``num_steps_forecast == 0``: ``Intermediates(growth_latents,
            seasonal_latents, level_output)``.
            - The latents are stacked as (batch, seq, layers, model_dim).
            - ``level_output`` is (batch, seq, time_features) and stays 3-D
              even for 2-D input.

            Otherwise: a forecast of shape
            (batch, num_steps_forecast, time_features), or
            (batch, num_steps_forecast) for 2-D input. The intermediates are
            returned as well when ``return_latents`` is True.
        """
        one_time_feature = x.ndim == 2

        if one_time_feature:
            x = rearrange(x, 'b n -> b n 1')

        z = self.embed(x)

        latent_growths = []
        latent_seasonals = []

        for freq_attn, mhes_attn, ff_block, level in self.encoder_layers:
            # seasonal latent, subtracted from the residual
            latent_seasonal = freq_attn(z)
            z = z - latent_seasonal

            # growth latent, subtracted from the residual
            latent_growth = mhes_attn(z)
            z = z - latent_growth

            if exists(ff_block):
                z = ff_block(z)

            # this layer's level output becomes the next layer's level input
            x = level(x, latent_growth, latent_seasonal)

            latent_growths.append(latent_growth)
            latent_seasonals.append(latent_seasonal)

        latent_growths = torch.stack(latent_growths, dim = -2)
        latent_seasonals = torch.stack(latent_seasonals, dim = -2)

        latents = Intermediates(latent_growths, latent_seasonals, x)

        if num_steps_forecast == 0:
            return latents

        # move time last so the Fourier series can be evaluated at future time indices
        latent_seasonals = rearrange(latent_seasonals, 'b n l d -> b l d n')
        extrapolated_seasonals = fourier_extrapolate(latent_seasonals, x.shape[1], x.shape[1] + num_steps_forecast)
        extrapolated_seasonals = rearrange(extrapolated_seasonals, 'b l d n -> b l n d')

        dampened_growths = self.growth_dampening_module(latent_growths, num_steps_forecast = num_steps_forecast)
        level = self.level_stack(x, num_steps_forecast = num_steps_forecast)

        # sum growth and seasonal forecasts over layers, then project to time features
        summed_latents = dampened_growths.sum(dim = 1) + extrapolated_seasonals.sum(dim = 1)
        forecasted = level + self.latents_to_time_features(summed_latents)

        if one_time_feature:
            forecasted = rearrange(forecasted, 'b n 1 -> b n')

        if return_latents:
            return forecasted, latents

        return forecasted
389 |
390 | # classification wrapper
391 |
class MultiheadLayerNorm(nn.Module):
    """Layer norm over the last dimension with separate affine parameters per head.

    ``g`` and ``b`` have shape (heads, 1, dim), so they broadcast against
    inputs laid out as (..., heads, seq, dim). ``eps`` is added to the
    (biased) standard deviation, not to the variance.
    """

    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(heads, 1, dim))
        self.b = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        centered = x - x.mean(dim = -1, keepdim = True)
        spread = x.var(dim = -1, unbiased = False, keepdim = True).sqrt()
        normalized = centered / (spread + self.eps)
        return normalized * self.g + self.b
403 |
class ClassificationWrapper(nn.Module):
    """Classify a time series with an ETSFormer encoder and cross-attention pooling.

    Three streams are turned into keys and values by a 1d convolution over
    time:
    - the growth latents, averaged over encoder layers;
    - the seasonal latents, averaged over encoder layers;
    - the level output.

    Learned per-head queries attend over the concatenation of the three
    streams. The pooled result is projected to class logits.
    """

    def __init__(
        self,
        *,
        etsformer,
        num_classes = 10,
        heads = 16,
        dim_head = 32,
        level_kernel_size = 3,
        growth_kernel_size = 3,
        seasonal_kernel_size = 3,
        dropout = 0.
    ):
        """
        Args:
            etsformer: the ``ETSFormer`` instance used as the encoder.
            num_classes: number of output logits.
            heads: number of pooling attention heads (one learned query each).
            dim_head: dimension per attention head.
            level_kernel_size: kernel size of the key/value convolution for
                the level stream.
            growth_kernel_size: kernel size of the key/value convolution for
                the growth stream.
            seasonal_kernel_size: kernel size of the key/value convolution for
                the seasonal stream.
            dropout: dropout on the attention weights.

        Raises:
            AssertionError: if ``etsformer`` is not an ``ETSFormer``.
        """
        super().__init__()
        assert isinstance(etsformer, ETSFormer)
        self.etsformer = etsformer
        model_dim = etsformer.model_dim
        time_features = etsformer.time_features

        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)

        # one learned query vector per head
        self.queries = nn.Parameter(torch.randn(heads, dim_head))

        # (b, n, model_dim) -> (b, 2 * heads, n', dim_head); keys in the first
        # `heads` channels of dim 1, values in the last `heads`
        self.growth_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            nn.Conv1d(model_dim, inner_dim * 2, growth_kernel_size, bias = False, padding = growth_kernel_size // 2),
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        # same layout as growth_to_kv, for the averaged seasonal latents
        self.seasonal_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            nn.Conv1d(model_dim, inner_dim * 2, seasonal_kernel_size, bias = False, padding = seasonal_kernel_size // 2),
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        # the level output has time_features channels rather than model_dim
        self.level_to_kv = nn.Sequential(
            Rearrange('b n t -> b t n'),
            nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
            Rearrange('b (kv h d) n -> b (kv h) n d', kv = 2, h = heads),
            MultiheadLayerNorm(dim_head, heads = 2 * heads),
        )

        self.to_out = nn.Linear(inner_dim, model_dim)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(model_dim),
            nn.Linear(model_dim, num_classes)
        )

    def forward(self, timeseries):
        """
        Args:
            timeseries: input accepted by ``ETSFormer.forward``, i.e. (b, n)
                or (b, n, time_features).

        Returns:
            Logits of shape (b, num_classes).
        """
        latent_growths, latent_seasonals, level_output = self.etsformer(timeseries)

        # average the per-layer latents: (b, n, layers, d) -> (b, n, d)
        latent_growths = latent_growths.mean(dim = -2)
        latent_seasonals = latent_seasonals.mean(dim = -2)

        # queries, key, values

        q = self.queries * self.scale

        # concatenate the three streams along the time axis (dim -2)
        kvs = torch.cat((
            self.growth_to_kv(latent_growths),
            self.seasonal_to_kv(latent_seasonals),
            self.level_to_kv(level_output)
        ), dim = -2)

        # kv is the outer factor of (kv h), so the first half of dim 1 holds the keys
        k, v = kvs.chunk(2, dim = 1)

        # cross attention pooling

        sim = einsum('h d, b h j d -> b h j', q, k)
        # subtract the detached max for numerical stability; softmax is unchanged
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h j, b h j d -> b h d', attn, v)
        # flatten heads: (b, heads, dim_head) -> (b, heads * dim_head)
        out = rearrange(out, 'b ... -> b (...)')

        out = self.to_out(out)

        # project to logits

        return self.to_logits(out)
491 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
setup(
    name = 'ETSformer-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.1.1',
    license='MIT',
    description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/ETSformer-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'time-series',
        'forecasting'
    ],
    install_requires=[
        'einops>=0.4',
        'scipy',
        # the package calls torch.fft.fft / rfft / irfft through the torch.fft
        # module without `import torch.fft`, which needs torch >= 1.8
        'torch>=1.8',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
33 |
--------------------------------------------------------------------------------