├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── cost.py
├── datasets
│   ├── PLACE_DATASETS_HERE
│   ├── electricity.py
│   └── m5.py
├── datautils.py
├── models
│   ├── __init__.py
│   ├── dilated_conv.py
│   └── encoder.py
├── pics
│   ├── CoST.png
│   └── results.png
├── requirements.txt
├── scripts
│   ├── ETT_CoST.sh
│   ├── Electricity_CoST.sh
│   ├── M5_CoST.sh
│   └── Weather_CoST.sh
├── tasks
│   ├── __init__.py
│   ├── _eval_protocols.py
│   └── forecasting.py
├── train.py
└── utils.py
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing.
2 | #ECCN:Open Source
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 |
3 | ## About the Code of Conduct
4 |
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 |
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 |
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 |
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 |
24 | ## Our Pledge
25 |
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 |
33 | ## Our Standards
34 |
35 | Examples of behavior that contributes to creating a positive environment
36 | include:
37 |
38 | * Using welcoming and inclusive language
39 | * Being respectful of differing viewpoints and experiences
40 | * Gracefully accepting constructive criticism
41 | * Focusing on what is best for the community
42 | * Showing empathy toward other community members
43 |
44 | Examples of unacceptable behavior by participants include:
45 |
46 | * The use of sexualized language or imagery and unwelcome sexual attention or
47 | advances
48 | * Personal attacks, insulting/derogatory comments, or trolling
49 | * Public or private harassment
50 | * Publishing, or threatening to publish, others' private information—such as
51 | a physical or electronic address—without explicit permission
52 | * Other conduct which could reasonably be considered inappropriate in a
53 | professional setting
54 | * Advocating for or encouraging any of the above behaviors
55 |
56 | ## Our Responsibilities
57 |
58 | Project maintainers are responsible for clarifying the standards of acceptable
59 | behavior and are expected to take appropriate and fair corrective action in
60 | response to any instances of unacceptable behavior.
61 |
62 | Project maintainers have the right and responsibility to remove, edit, or
63 | reject comments, commits, code, wiki edits, issues, and other contributions
64 | that are not aligned with this Code of Conduct, or to ban temporarily or
65 | permanently any contributor for other behaviors that they deem inappropriate,
66 | threatening, offensive, or harmful.
67 |
68 | ## Scope
69 |
70 | This Code of Conduct applies both within project spaces and in public spaces
71 | when an individual is representing the project or its community. Examples of
72 | representing a project or community include using an official project email
73 | address, posting via an official social media account, or acting as an appointed
74 | representative at an online or offline event. Representation of a project may be
75 | further defined and clarified by project maintainers.
76 |
77 | ## Enforcement
78 |
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 |
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 |
92 | ## Attribution
93 |
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 |
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 |
101 | [contributor-covenant-home]: https://www.contributor-covenant.org
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
2 |
3 |
4 |
5 |
6 | Figure 1. Overall CoST Architecture.
7 |
8 |
9 | Official PyTorch code repository for the [CoST paper](https://openreview.net/forum?id=PilZY3omXV2).
10 |
11 | * CoST is a contrastive learning method for learning disentangled seasonal-trend representations for time series forecasting.
12 | * CoST consistently outperforms state-of-the-art methods by a considerable margin, achieving a 21.3% improvement in MSE on multivariate benchmarks.
13 |
14 | ## Requirements
15 | 1. Install Python 3.8 and the required dependencies.
16 | 2. Required dependencies can be installed by: ```pip install -r requirements.txt```
17 |
18 | ## Data
19 |
20 | The datasets can be obtained and put into `datasets/` folder in the following way:
21 |
22 | * [3 ETT datasets](https://github.com/zhouhaoyi/ETDataset) should be placed at `datasets/ETTh1.csv`, `datasets/ETTh2.csv` and `datasets/ETTm1.csv`.
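23 | * [Electricity dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014) should be placed at `datasets/LD2011_2014.txt`; then run `electricity.py` from inside `datasets/` (it reads and writes paths relative to the working directory) to produce `datasets/electricity.csv`.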
24 | * [Weather dataset](https://drive.google.com/drive/folders/1ohGYWWohJlOlb2gsGTeEq3Wii2egnEPR) (link from [Informer repository](https://github.com/zhouhaoyi/Informer2020)) placed at `datasets/WTH.csv`
25 | * [M5 dataset](https://drive.google.com/drive/folders/1D6EWdVSaOtrP1LEFh1REjI3vej6iUS_4): place `calendar.csv`, `sales_train_validation.csv`, `sales_train_evaluation.csv`, `sales_test_validation.csv` and `sales_test_evaluation.csv` in `datasets/`, then run `m5.py` from inside `datasets/` to produce the aggregated files `M5-l1.csv` to `M5-l10.csv`.
26 |
27 | ## Usage
28 | To train and evaluate CoST on a dataset, run the corresponding script from the `scripts` folder, e.g. ```./scripts/ETT_CoST.sh``` (make the scripts executable first with ```chmod u+x scripts/*```).
29 |
30 | After training and evaluation, the trained encoder, output and evaluation metrics can be found in `training//__/`.
31 |
32 | Alternatively, you can directly run the python scripts:
33 | ```train & evaluate
34 | python train.py --archive --batch-size --repr-dims --gpu --eval
35 | ```
36 | The arguments are described as follows:
37 | | Parameter name | Description of parameter |
38 | | --- | --- |
39 | | dataset_name | The dataset name |
40 | | run_name | The folder name used to save model, output and evaluation metrics. This can be set to any word |
41 | | archive | The archive name that the dataset belongs to. This can be set to `forecast_csv` or `forecast_csv_univar` |
42 | | batch_size | The batch size (defaults to 8) |
43 | | repr_dims | The representation dimensions (defaults to 320) |
44 | | gpu | The gpu no. used for training and inference (defaults to 0) |
45 | | eval | Whether to perform evaluation after training |
46 | | kernels | Kernel sizes for mixture of AR experts module |
47 | | alpha | Weight of the seasonal (frequency-domain) contrastive loss term in the total loss |
48 |
49 | (For descriptions of more arguments, run `python train.py -h`.)
50 |
51 | ## Main Results
52 | We perform experiments on five real-world public benchmark datasets, comparing against both state-of-the-art representation learning and end-to-end forecasting approaches.
53 | CoST achieves state-of-the-art performance, beating the best performing end-to-end forecasting approach by 39.3% and 18.22% (MSE) in the multivariate and univariate settings
54 | respectively. CoST also beats the next best performing feature-based approach by 21.3% and 4.71% (MSE) in the multivariate and univariate settings respectively (refer to the main paper for full results).
55 |
56 |
57 |
58 |
59 |
60 | ## FAQs
61 | **Q**: ValueError: Found array with dim 4. StandardScaler expected <= 2.
62 |
63 | **A**: Please install the appropriate package requirements as found in ```requirements.txt```, in particular, ```scikit_learn==0.24.1```.
64 |
65 | **Q**: How to set the ``--kernels`` parameter?
66 |
67 | **A**: It should be a list of space-separated integers, e.g. ```--kernels 1 2 4```. See the `scripts` folder for further examples.
68 |
69 | ## Acknowledgements
70 | The implementation of CoST relies on resources from the following codebases and repositories; we thank the original authors for open-sourcing their work.
71 | * https://github.com/yuezhihan/ts2vec
72 | * https://github.com/zhouhaoyi/Informer2020
73 |
74 | ## Citation
75 | Please consider citing our paper if you find this code useful for your research.
76 | @inproceedings{
77 | woo2022cost,
78 | title={Co{ST}: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting},
79 | author={Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
80 | booktitle={International Conference on Learning Representations},
81 | year={2022},
82 | url={https://openreview.net/forum?id=PilZY3omXV2}
83 | }
84 |
--------------------------------------------------------------------------------
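The FAQ's `--kernels 1 2 4` form is what Python's `argparse` produces for an option declared with `type=int` and `nargs='+'`. `train.py` is not included in this dump, so the parser below is a hypothetical sketch. It only mirrors the argument names and defaults listed in the README table and should not be read as the script's actual definitions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the README argument table; the real
    # definitions live in train.py, which is not reproduced here.
    parser = argparse.ArgumentParser(description="Train and evaluate CoST (sketch)")
    parser.add_argument("dataset", help="The dataset name")
    parser.add_argument("run_name", help="Folder name for the model, output and metrics")
    parser.add_argument("--archive", required=True,
                        choices=["forecast_csv", "forecast_csv_univar"])
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--repr-dims", type=int, default=320)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--eval", action="store_true")
    # nargs="+" gathers one or more space-separated integers into a list;
    # leaving it unset yields None, which CoST.__init__ turns into [].
    parser.add_argument("--kernels", type=int, nargs="+", default=None)
    return parser


args = build_parser().parse_args(
    ["ETTh1", "demo", "--archive", "forecast_csv", "--kernels", "1", "2", "4", "--eval"]
)
print(args.kernels)     # [1, 2, 4]
print(args.batch_size)  # 8
```

Note that `--batch-size` is exposed as `args.batch_size`, because `argparse` converts dashes in option names to underscores. This matches the `batch_size` spelling used in the README table.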
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/cost.py:
--------------------------------------------------------------------------------
1 | import sys, math, random, copy
2 | from typing import Union, Callable, Optional, List
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.fft as fft
8 | from torch.utils.data import TensorDataset, DataLoader, Dataset
9 |
10 | import numpy as np
11 | from einops import rearrange, repeat, reduce
12 |
13 | from models.encoder import CoSTEncoder
14 | from utils import take_per_row, split_with_nan, centerize_vary_length_series, torch_pad_nan
15 |
16 |
17 | class PretrainDataset(Dataset):
18 |
19 |     def __init__(self,
20 |                  data,
21 |                  sigma,
22 |                  p=0.5,
23 |                  multiplier=10):
24 |         super().__init__()
25 |         self.data = data
26 |         self.p = p
27 |         self.sigma = sigma
28 |         self.multiplier = multiplier
29 |         self.N, self.T, self.D = data.shape  # num_ts, time, dim
30 |
31 |     def __getitem__(self, item):
32 |         ts = self.data[item % self.N]
33 |         return self.transform(ts), self.transform(ts)
34 |
35 |     def __len__(self):
36 |         return self.data.size(0) * self.multiplier
37 |
38 |     def transform(self, x):
39 |         return self.jitter(self.shift(self.scale(x)))
40 |
41 |     def jitter(self, x):
42 |         if random.random() > self.p:
43 |             return x
44 |         return x + (torch.randn(x.shape) * self.sigma)
45 |
46 |     def scale(self, x):
47 |         if random.random() > self.p:
48 |             return x
49 |         return x * (torch.randn(x.size(-1)) * self.sigma + 1)
50 |
51 |     def shift(self, x):
52 |         if random.random() > self.p:
53 |             return x
54 |         return x + (torch.randn(x.size(-1)) * self.sigma)
55 |
56 |
57 | class CoSTModel(nn.Module):
58 |     def __init__(self,
59 |                  encoder_q: nn.Module, encoder_k: nn.Module,
60 |                  kernels: List[int],
61 |                  device: Optional[str] = 'cuda',
62 |                  dim: Optional[int] = 128,
63 |                  alpha: Optional[float] = 0.05,
64 |                  K: Optional[int] = 65536,
65 |                  m: Optional[float] = 0.999,
66 |                  T: Optional[float] = 0.07):
67 |         super().__init__()
68 |
69 |         self.K = K
70 |         self.m = m
71 |         self.T = T
72 |         self.device = device
73 |
74 |         self.kernels = kernels
75 |
76 |         self.alpha = alpha
77 |
78 |         self.encoder_q = encoder_q
79 |         self.encoder_k = encoder_k
80 |
81 |         # create the projection heads for the query and key branches
82 |         self.head_q = nn.Sequential(
83 |             nn.Linear(dim, dim),
84 |             nn.ReLU(),
85 |             nn.Linear(dim, dim)
86 |         )
87 |         self.head_k = nn.Sequential(
88 |             nn.Linear(dim, dim),
89 |             nn.ReLU(),
90 |             nn.Linear(dim, dim)
91 |         )
92 |
93 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
94 |             param_k.data.copy_(param_q.data)  # initialize
95 |             param_k.requires_grad = False  # not updated by gradient
96 |         for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
97 |             param_k.data.copy_(param_q.data)  # initialize
98 |             param_k.requires_grad = False  # not updated by gradient
99 |
100 |         self.register_buffer('queue', F.normalize(torch.randn(dim, K), dim=0))
101 |         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
102 |
103 |
104 |     def compute_loss(self, q, k, k_negs):
105 |         # compute logits
106 |         # positive logits: Nx1
107 |         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
108 |         # negative logits: NxK
109 |         l_neg = torch.einsum('nc,ck->nk', [q, k_negs])
110 |
111 |         # logits: Nx(1+K)
112 |         logits = torch.cat([l_pos, l_neg], dim=1)
113 |
114 |         # apply temperature
115 |         logits /= self.T
116 |
117 |         # labels: the positive key is always at index 0; create them on the logits' device (not hard-coded CUDA)
118 |         labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
119 |         loss = F.cross_entropy(logits, labels)
120 |
121 |         return loss
122 |
123 |     def convert_coeff(self, x, eps=1e-6):
124 |         amp = torch.sqrt((x.real + eps).pow(2) + (x.imag + eps).pow(2))
125 |         phase = torch.atan2(x.imag, x.real + eps)
126 |         return amp, phase
127 |
128 |     def instance_contrastive_loss(self, z1, z2):
129 |         B, T = z1.size(0), z1.size(1)
130 |         z = torch.cat([z1, z2], dim=0)  # 2B x T x C
131 |         z = z.transpose(0, 1)  # T x 2B x C
132 |         sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
133 |         logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
134 |         logits += torch.triu(sim, diagonal=1)[:, :, 1:]
135 |         logits = -F.log_softmax(logits, dim=-1)
136 |
137 |         i = torch.arange(B, device=z1.device)
138 |         loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
139 |         return loss
140 |
141 |     def forward(self, x_q, x_k):
142 |         # compute query features
143 |         rand_idx = np.random.randint(0, x_q.shape[1])
144 |
145 |         q_t, q_s = self.encoder_q(x_q)
146 |         if q_t is not None:
147 |             q_t = F.normalize(self.head_q(q_t[:, rand_idx]), dim=-1)
148 |
149 |         # compute key features
150 |         with torch.no_grad():  # no gradient for keys
151 |             self._momentum_update_key_encoder()  # update key encoder
152 |             k_t, k_s = self.encoder_k(x_k)
153 |             if k_t is not None:
154 |                 k_t = F.normalize(self.head_k(k_t[:, rand_idx]), dim=-1)
155 |
156 |         loss = 0
157 |
158 |         loss += self.compute_loss(q_t, k_t, self.queue.clone().detach())
159 |         self._dequeue_and_enqueue(k_t)
160 |
161 |         q_s = F.normalize(q_s, dim=-1)
162 |         _, k_s = self.encoder_q(x_k)
163 |         k_s = F.normalize(k_s, dim=-1)
164 |
165 |         q_s_freq = fft.rfft(q_s, dim=1)
166 |         k_s_freq = fft.rfft(k_s, dim=1)
167 |         q_s_amp, q_s_phase = self.convert_coeff(q_s_freq)
168 |         k_s_amp, k_s_phase = self.convert_coeff(k_s_freq)
169 |
170 |         seasonal_loss = self.instance_contrastive_loss(q_s_amp, k_s_amp) + \
171 |                         self.instance_contrastive_loss(q_s_phase, k_s_phase)
172 |         loss += (self.alpha * (seasonal_loss / 2))
173 |
174 |         return loss
175 |
176 |     @torch.no_grad()
177 |     def _momentum_update_key_encoder(self):
178 |         """
179 |         Momentum update for key encoder
180 |         """
181 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
182 |             param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)
183 |         for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
184 |             param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)
185 |
186 |     @torch.no_grad()
187 |     def _dequeue_and_enqueue(self, keys):
188 |         batch_size = keys.shape[0]
189 |
190 |         ptr = int(self.queue_ptr)
191 |         assert self.K % batch_size == 0
192 |
193 |         # replace keys at ptr (dequeue and enqueue)
194 |         self.queue[:, ptr:ptr + batch_size] = keys.T
195 |
196 |         ptr = (ptr + batch_size) % self.K
197 |         self.queue_ptr[0] = ptr
198 |
199 |
200 | class CoST:
201 |     def __init__(self,
202 |                  input_dims: int,
203 |                  kernels: List[int],
204 |                  alpha: float,
205 |                  max_train_length: int,
206 |                  output_dims: int = 320,
207 |                  hidden_dims: int = 64,
208 |                  depth: int = 10,
209 |                  device: str = 'cuda',
210 |                  lr: float = 0.001,
211 |                  batch_size: int = 16,
212 |                  after_iter_callback: Union[Callable, None] = None,
213 |                  after_epoch_callback: Union[Callable, None] = None):
214 |
215 |         super().__init__()
216 |         self.input_dims = input_dims
217 |         self.output_dims = output_dims
218 |         self.hidden_dims = hidden_dims
219 |         self.device = device
220 |         self.lr = lr
221 |         self.batch_size = batch_size
222 |         self.max_train_length = max_train_length
223 |
224 |         if kernels is None:
225 |             kernels = []
226 |
227 |         self.net = CoSTEncoder(
228 |             input_dims=input_dims, output_dims=output_dims,
229 |             kernels=kernels,
230 |             length=max_train_length,
231 |             hidden_dims=hidden_dims, depth=depth,
232 |         ).to(self.device)
233 |
234 |         self.cost = CoSTModel(
235 |             self.net,
236 |             copy.deepcopy(self.net),
237 |             kernels=kernels,
238 |             dim=self.net.component_dims,
239 |             alpha=alpha,
240 |             K=256,
241 |             device=self.device,
242 |         ).to(self.device)
243 |
244 |         self.after_iter_callback = after_iter_callback
245 |         self.after_epoch_callback = after_epoch_callback
246 |
247 |         self.n_epochs = 0
248 |         self.n_iters = 0
249 |
250 |     def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
251 |         assert train_data.ndim == 3
252 |
253 |         if n_iters is None and n_epochs is None:
254 |             n_iters = 200 if train_data.size <= 100000 else 600
255 |
256 |         if self.max_train_length is not None:
257 |             sections = train_data.shape[1] // self.max_train_length
258 |             if sections >= 2:
259 |                 train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0)
260 |
261 |         temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
262 |         if temporal_missing[0] or temporal_missing[-1]:
263 |             train_data = centerize_vary_length_series(train_data)
264 |
265 |         train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]
266 |
267 |         multiplier = 1 if train_data.shape[0] >= self.batch_size else math.ceil(self.batch_size / train_data.shape[0])
268 |         train_dataset = PretrainDataset(torch.from_numpy(train_data).to(torch.float), sigma=0.5, multiplier=multiplier)
269 |         train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True)
270 |
271 |         optimizer = torch.optim.SGD([p for p in self.cost.parameters() if p.requires_grad],
272 |                                     lr=self.lr,
273 |                                     momentum=0.9,
274 |                                     weight_decay=1e-4)
275 |
276 |         loss_log = []
277 |
278 |         while True:
279 |             if n_epochs is not None and self.n_epochs >= n_epochs:
280 |                 break
281 |
282 |             cum_loss = 0
283 |             n_epoch_iters = 0
284 |
285 |             interrupted = False
286 |             for batch in train_loader:
287 |                 if n_iters is not None and self.n_iters >= n_iters:
288 |                     interrupted = True
289 |                     break
290 |
291 |                 x_q, x_k = map(lambda x: x.to(self.device), batch)
292 |                 if self.max_train_length is not None and x_q.size(1) > self.max_train_length:
293 |                     window_offset = np.random.randint(x_q.size(1) - self.max_train_length + 1)
294 |                     x_q = x_q[:, window_offset : window_offset + self.max_train_length]
295 |                     x_k = x_k[:, window_offset : window_offset + self.max_train_length]
296 |
297 |                 optimizer.zero_grad()
298 |
299 |                 loss = self.cost(x_q, x_k)
300 |
301 |                 loss.backward()
302 |                 optimizer.step()
303 |
304 |                 cum_loss += loss.item()
305 |                 n_epoch_iters += 1
306 |
307 |                 self.n_iters += 1
308 |
309 |                 if self.after_iter_callback is not None:
310 |                     self.after_iter_callback(self, loss.item())
311 |
312 |                 if n_iters is not None:
313 |                     adjust_learning_rate(optimizer, self.lr, self.n_iters, n_iters)
314 |
315 |             if interrupted:
316 |                 break
317 |
318 |             cum_loss /= n_epoch_iters
319 |             loss_log.append(cum_loss)
320 |             if verbose:
321 |                 print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
322 |             self.n_epochs += 1
323 |
324 |             if self.after_epoch_callback is not None:
325 |                 self.after_epoch_callback(self, cum_loss)
326 |
327 |             if n_epochs is not None:
328 |                 adjust_learning_rate(optimizer, self.lr, self.n_epochs, n_epochs)
329 |
330 |         return loss_log
331 |
332 |     def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
333 |         out_t, out_s = self.net(x.to(self.device, non_blocking=True))  # b t d
334 |         out = torch.cat([out_t[:, -1], out_s[:, -1]], dim=-1)
335 |         return rearrange(out.cpu(), 'b d -> b () d')
336 |
337 |     def encode(self, data, mode, mask=None, encoding_window=None, casual=False, sliding_length=None, sliding_padding=0, batch_size=None):
338 |         if mode == 'forecasting':
339 |             encoding_window = None
340 |             slicing = None
341 |         else:
342 |             raise NotImplementedError(f"mode {mode} has not been implemented")
343 |
344 |         assert data.ndim == 3
345 |         if batch_size is None:
346 |             batch_size = self.batch_size
347 |         n_samples, ts_l, _ = data.shape
348 |
349 |         org_training = self.net.training
350 |         self.net.eval()
351 |
352 |         dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
353 |         loader = DataLoader(dataset, batch_size=batch_size)
354 |
355 |         with torch.no_grad():
356 |             output = []
357 |             for batch in loader:
358 |                 x = batch[0]
359 |                 if sliding_length is not None:
360 |                     reprs = []
361 |                     if n_samples < batch_size:
362 |                         calc_buffer = []
363 |                         calc_buffer_l = 0
364 |                     for i in range(0, ts_l, sliding_length):
365 |                         l = i - sliding_padding
366 |                         r = i + sliding_length + (sliding_padding if not casual else 0)
367 |                         x_sliding = torch_pad_nan(
368 |                             x[:, max(l, 0) : min(r, ts_l)],
369 |                             left=-l if l<0 else 0,
370 |                             right=r-ts_l if r>ts_l else 0,
371 |                             dim=1
372 |                         )
373 |                         if n_samples < batch_size:
374 |                             if calc_buffer_l + n_samples > batch_size:
375 |                                 out = self._eval_with_pooling(
376 |                                     torch.cat(calc_buffer, dim=0),
377 |                                     mask,
378 |                                     slicing=slicing,
379 |                                     encoding_window=encoding_window
380 |                                 )
381 |                                 reprs += torch.split(out, n_samples)
382 |                                 calc_buffer = []
383 |                                 calc_buffer_l = 0
384 |                             calc_buffer.append(x_sliding)
385 |                             calc_buffer_l += n_samples
386 |                         else:
387 |                             out = self._eval_with_pooling(
388 |                                 x_sliding,
389 |                                 mask,
390 |                                 slicing=slicing,
391 |                                 encoding_window=encoding_window
392 |                             )
393 |                             reprs.append(out)
394 |
395 |                     if n_samples < batch_size:
396 |                         if calc_buffer_l > 0:
397 |                             out = self._eval_with_pooling(
398 |                                 torch.cat(calc_buffer, dim=0),
399 |                                 mask,
400 |                                 slicing=slicing,
401 |                                 encoding_window=encoding_window
402 |                             )
403 |                             reprs += torch.split(out, n_samples)
404 |                             calc_buffer = []
405 |                             calc_buffer_l = 0
406 |
407 |                     out = torch.cat(reprs, dim=1)
408 |                     if encoding_window == 'full_series':
409 |                         out = F.max_pool1d(
410 |                             out.transpose(1, 2).contiguous(),
411 |                             kernel_size = out.size(1),
412 |                         ).squeeze(1)
413 |                 else:
414 |                     out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
415 |                     if encoding_window == 'full_series':
416 |                         out = out.squeeze(1)
417 |
418 |                 output.append(out)
419 |
420 |             output = torch.cat(output, dim=0)
421 |
422 |         self.net.train(org_training)
423 |         return output.numpy()
424 |
425 |     def save(self, fn):
426 |         ''' Save the model to a file.
427 |
428 |         Args:
429 |             fn (str): filename.
430 |         '''
431 |         torch.save(self.net.state_dict(), fn)
432 |
433 |     def load(self, fn):
434 |         ''' Load the model from a file.
435 |
436 |         Args:
437 |             fn (str): filename.
438 |         '''
439 |         state_dict = torch.load(fn, map_location=self.device)
440 |         self.net.load_state_dict(state_dict)
441 |
442 |
443 | def adjust_learning_rate(optimizer, lr, epoch, epochs):
444 |     """Decay the learning rate with a half-cycle cosine schedule."""
445 |     lr *= 0.5 * (1. + math.cos(math.pi * epoch / epochs))
446 |     for param_group in optimizer.param_groups:
447 |         param_group['lr'] = lr
448 |
--------------------------------------------------------------------------------
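The seasonal term in `CoSTModel.forward` works in four steps:

1. L2-normalise the seasonal features.
2. Move them to the frequency domain with `rfft` along the time axis.
3. Split the coefficients into amplitude and phase with `convert_coeff`.
4. Apply `instance_contrastive_loss` to each part, then average the two losses.

The sketch below re-expresses that logic in NumPy so the diagonal-removal indexing can be checked without PyTorch. It is an illustrative re-implementation, not the repository code. It assumes `(B, T, C)` inputs and omits the final `alpha` weighting.

```python
import numpy as np


def convert_coeff(x: np.ndarray, eps: float = 1e-6):
    # Amplitude and phase of complex coefficients (same formula as CoSTModel.convert_coeff)
    amp = np.sqrt((x.real + eps) ** 2 + (x.imag + eps) ** 2)
    phase = np.arctan2(x.imag, x.real + eps)
    return amp, phase


def log_softmax(a: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log-softmax
    shifted = a - a.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def instance_contrastive_loss(z1: np.ndarray, z2: np.ndarray) -> float:
    # z1, z2: (B, F, C) features of two views; each frequency F is contrasted separately
    B = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0).transpose(1, 0, 2)  # F x 2B x C
    sim = z @ z.transpose(0, 2, 1)                           # F x 2B x 2B
    # Drop the self-similarity diagonal: F x 2B x (2B-1)
    logits = np.tril(sim, k=-1)[:, :, :-1] + np.triu(sim, k=1)[:, :, 1:]
    nll = -log_softmax(logits, axis=-1)
    i = np.arange(B)
    # Row i's positive sits at column B+i, shifted left by one once the diagonal is dropped;
    # row B+i's positive sits at column i, which lies below the diagonal and keeps its index.
    return float((nll[:, i, B + i - 1].mean() + nll[:, B + i, i].mean()) / 2)


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def seasonal_loss(q_s: np.ndarray, k_s: np.ndarray) -> float:
    # q_s, k_s: (B, T, C) seasonal features of the two augmented views
    q_freq = np.fft.rfft(l2_normalize(q_s), axis=1)  # B x (T//2 + 1) x C
    k_freq = np.fft.rfft(l2_normalize(k_s), axis=1)
    q_amp, q_phase = convert_coeff(q_freq)
    k_amp, k_phase = convert_coeff(k_freq)
    return (instance_contrastive_loss(q_amp, k_amp)
            + instance_contrastive_loss(q_phase, k_phase)) / 2


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 16, 8))
k = q + 0.01 * rng.normal(size=q.shape)
print(seasonal_loss(q, k))
```

The `tril`/`triu` slicing builds a `(2B-1)`-column logit matrix without the self-similarity term. In this matrix, each sample's positive is its counterpart from the other view, and the remaining `2B-2` samples act as negatives.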
/datasets/PLACE_DATASETS_HERE:
--------------------------------------------------------------------------------
1 | Please follow the instructions in README.md to place the datasets into this folder.
--------------------------------------------------------------------------------
/datasets/electricity.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | data_ecl = pd.read_csv('LD2011_2014.txt', parse_dates=True, sep=';', decimal=',', index_col=0)
3 | data_ecl = data_ecl.resample('1h', closed='right').sum()
4 | data_ecl = data_ecl.loc[:, data_ecl.cumsum(axis=0).iloc[8920] != 0] # filter out instances with missing values
5 | data_ecl.index = data_ecl.index.rename('date')
6 | data_ecl = data_ecl['2012':]
7 | data_ecl.to_csv('electricity.csv')
--------------------------------------------------------------------------------
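`electricity.py` sums the raw readings into hourly totals with `resample('1h', closed='right')`. Each bin therefore covers the half-open interval (h, h+1] and is labelled by its left edge. The script then drops clients whose cumulative consumption is still zero at row 8920 of the hourly frame. The toy example below reproduces those two steps on synthetic 15-minute data. `to_hourly` and `cutoff_row` are hypothetical names, and `cutoff_row` stands in for the hard-coded 8920.

```python
import numpy as np
import pandas as pd


def to_hourly(df: pd.DataFrame, cutoff_row: int) -> pd.DataFrame:
    # Sum readings into hourly bins that are closed on the right: (h, h+1]
    hourly = df.resample("1h", closed="right").sum()
    # Keep only clients with non-zero cumulative consumption by `cutoff_row`
    keep = hourly.cumsum(axis=0).iloc[cutoff_row] != 0
    hourly = hourly.loc[:, keep]
    hourly.index = hourly.index.rename("date")
    return hourly


idx = pd.date_range("2011-01-01 00:15", periods=8, freq="15min")
raw = pd.DataFrame({
    "MT_001": np.ones(8),                        # active from the start
    "MT_002": np.r_[np.zeros(4), np.ones(4)],    # starts reporting in the second hour
    "MT_003": np.zeros(8),                       # never reports
}, index=idx)

hourly = to_hourly(raw, cutoff_row=0)
print(hourly)
```

With `cutoff_row=0`, only `MT_001` survives, because `MT_002` has not yet recorded any consumption by the first hourly row.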
/datasets/m5.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | calendar = pd.read_csv('calendar.csv', index_col='date', parse_dates=True)
5 | train_validation = pd.read_csv('sales_train_validation.csv')
6 | train_evaluation = pd.read_csv('sales_train_evaluation.csv')
7 | test_validation = pd.read_csv('sales_test_validation.csv')
8 | test_evaluation = pd.read_csv('sales_test_evaluation.csv')
9 |
10 | all_data = pd.merge(
11 |     train_evaluation,
12 |     test_evaluation,
13 |     how="inner",
14 |     on=None,
15 |     left_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
16 |     right_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
17 |     sort=False,
18 |     suffixes=("_x", "_y"),
19 |     copy=True,
20 |     indicator=False,
21 |     validate=None,
22 | )
23 |
24 | groups = {
25 |     'l1': ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
26 |     'l2': ['state_id'],
27 |     'l3': ['store_id'],
28 |     'l4': ['cat_id'],
29 |     'l5': ['dept_id'],
30 |     'l6': ['state_id', 'cat_id'],
31 |     'l7': ['state_id', 'dept_id'],
32 |     'l8': ['store_id', 'cat_id'],
33 |     'l9': ['store_id', 'dept_id'],
34 |     'l10': ['item_id'],
35 | }
36 |
37 | for k, v in groups.items():
38 |     if k == 'l1':
39 |         grouped_data = all_data.drop(columns=v).sum().to_frame(name='total')
40 |     else:
41 |         grouped_data = all_data.groupby(v).sum().transpose()
42 |     grouped_data['date'] = calendar.index
43 |     grouped_data = grouped_data.set_index('date')
44 |
45 |     if isinstance(grouped_data.columns, pd.MultiIndex):
46 |         grouped_data.columns = [c[0] + "_" + c[1] for c in grouped_data.columns]
47 |
48 |     grouped_data.to_csv(f'M5-{k}.csv', index=True)
49 |
--------------------------------------------------------------------------------
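`m5.py` builds the hierarchy levels `l1` to `l10` in two ways:

* For the top level, it sums every day column into one `total` series.
* For the other levels, it groups by one or two ID columns, sums, and transposes, so that each group becomes a column and each day a row. Two-key groups produce `MultiIndex` columns, which are flattened to `key1_key2`.

The toy sketch below shows the same aggregation on a four-row sales table. Unlike `m5.py`, it selects the day columns explicitly before summing, so that string ID columns never reach `sum()`. The helper name `aggregate` and the dates are hypothetical.

```python
import pandas as pd

sales = pd.DataFrame({
    "item_id": ["A", "B", "A", "B"],
    "store_id": ["CA_1", "CA_1", "TX_1", "TX_1"],
    "state_id": ["CA", "CA", "TX", "TX"],
    "d_1": [1, 2, 3, 4],
    "d_2": [0, 1, 0, 2],
    "d_3": [5, 0, 1, 1],
})
day_cols = [c for c in sales.columns if c.startswith("d_")]
# Toy calendar: one date per day column
dates = pd.date_range("2011-01-29", periods=len(day_cols), freq="D", name="date")


def aggregate(df: pd.DataFrame, keys: list) -> pd.DataFrame:
    if not keys:
        # Top of the hierarchy: one total series
        out = df[day_cols].sum().to_frame(name="total")
    else:
        # One column per group, one row per day
        out = df.groupby(keys)[day_cols].sum().transpose()
    out.index = dates
    if isinstance(out.columns, pd.MultiIndex):
        out.columns = ["_".join(c) for c in out.columns]
    return out


print(aggregate(sales, []))
print(aggregate(sales, ["store_id"]))
print(aggregate(sales, ["state_id", "store_id"]))
```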
/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
4 |
5 |
6 | def load_forecast_npy(name, univar=False):
7 |     data = np.load(f'datasets/{name}.npy')
8 |     if univar:
9 |         data = data[:, -1:]  # keep only the last variable (column) as the univariate target
10 |
11 |     train_slice = slice(None, int(0.6 * len(data)))
12 |     valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data)))
13 |     test_slice = slice(int(0.8 * len(data)), None)
14 |
15 |     scaler = StandardScaler().fit(data[train_slice])
16 |     data = scaler.transform(data)
17 |     data = np.expand_dims(data, 0)
18 |
19 |     pred_lens = [24, 48, 96, 288, 672]
20 |     return data, train_slice, valid_slice, test_slice, scaler, pred_lens, 0
21 |
22 | def _get_time_features(dt):
23 |     return np.stack([
24 |         dt.minute.to_numpy(),
25 |         dt.hour.to_numpy(),
26 |         dt.dayofweek.to_numpy(),
27 |         dt.day.to_numpy(),
28 |         dt.dayofyear.to_numpy(),
29 |         dt.month.to_numpy(),
30 |         dt.isocalendar().week.to_numpy(dtype=np.int64),  # ISO week; DatetimeIndex.weekofyear was removed in pandas 2.0
31 |     ], axis=1).astype(np.float64)  # np.float was removed in NumPy 1.24
32 |
33 | def load_forecast_csv(name, univar=False):
34 | data = pd.read_csv(f'datasets/{name}.csv', index_col='date', parse_dates=True)
35 | dt_embed = _get_time_features(data.index)
36 | n_covariate_cols = dt_embed.shape[-1]
37 |
38 | if univar:
39 | if name in ('ETTh1', 'ETTh2', 'ETTm1', 'ETTm2'):
40 | data = data[['OT']]
41 | elif name == 'electricity':
42 | data = data[['MT_001']]
43 | elif name == 'WTH':
44 | data = data[['WetBulbCelsius']]
45 | else:
46 | data = data.iloc[:, -1:]
47 |
48 | data = data.to_numpy()
49 | if name == 'ETTh1' or name == 'ETTh2':
50 | train_slice = slice(None, 12 * 30 * 24)
51 | valid_slice = slice(12 * 30 * 24, 16 * 30 * 24)
52 | test_slice = slice(16 * 30 * 24, 20 * 30 * 24)
53 | elif name == 'ETTm1' or name == 'ETTm2':
54 | train_slice = slice(None, 12 * 30 * 24 * 4)
55 | valid_slice = slice(12 * 30 * 24 * 4, 16 * 30 * 24 * 4)
56 | test_slice = slice(16 * 30 * 24 * 4, 20 * 30 * 24 * 4)
57 | elif name.startswith('M5'):
58 | train_slice = slice(None, int(0.8 * (1913 + 28)))
59 | valid_slice = slice(int(0.8 * (1913 + 28)), 1913 + 28)
60 | test_slice = slice(1913 + 28 - 1, 1913 + 2 * 28)
61 | else:
62 | train_slice = slice(None, int(0.6 * len(data)))
63 | valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data)))
64 | test_slice = slice(int(0.8 * len(data)), None)
65 |
66 | scaler = StandardScaler().fit(data[train_slice])
67 | data = scaler.transform(data)
68 | if name == 'electricity' or name.startswith('M5'):  # ('electricity') is a str, so `in` did a substring test
69 | data = np.expand_dims(data.T, -1) # Each variable is an instance rather than a feature
70 | else:
71 | data = np.expand_dims(data, 0)
72 |
73 | if n_covariate_cols > 0:
74 | dt_scaler = StandardScaler().fit(dt_embed[train_slice])
75 | dt_embed = np.expand_dims(dt_scaler.transform(dt_embed), 0)
76 | data = np.concatenate([np.repeat(dt_embed, data.shape[0], axis=0), data], axis=-1)
77 |
78 | if name in ('ETTh1', 'ETTh2', 'electricity', 'WTH'):
79 | pred_lens = [24, 48, 168, 336, 720]
80 | elif name.startswith('M5'):
81 | pred_lens = [28]
82 | else:
83 | pred_lens = [24, 48, 96, 288, 672]
84 |
85 | return data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols
86 |
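`load_forecast_csv` splits each series chronologically and fits `StandardScaler` on the training rows only, so validation and test statistics never leak into the normalisation. For ETTh1/ETTh2 the split is 12/4/4 thirty-day months of hourly samples. The NumPy sketch below covers both steps. The function names are hypothetical, and the zero-variance handling mirrors what `StandardScaler` does for constant columns.

```python
import numpy as np


def ett_hourly_slices():
    """12/4/4 months of 30 days at 24 samples per day (ETTh1/ETTh2 convention)."""
    month = 30 * 24
    return slice(None, 12 * month), slice(12 * month, 16 * month), slice(16 * month, 20 * month)


def standardize_with_train_stats(data: np.ndarray, train_slice: slice):
    """Scale all rows using the mean/std of the training rows only."""
    train = data[train_slice]
    mean = train.mean(axis=0)
    std = train.std(axis=0)  # population std (ddof=0), as StandardScaler uses
    std = np.where(std == 0, 1.0, std)  # leave constant columns unscaled
    return (data - mean) / std, mean, std


if __name__ == "__main__":
    train, valid, test = ett_hourly_slices()
    print(train.stop, valid.stop, test.stop)
```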
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoST/afc26aa0239470f522135f470861a1c375507e84/models/__init__.py
--------------------------------------------------------------------------------
/models/dilated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class SamePadConv(nn.Module):
7 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
8 | super().__init__()
9 | self.receptive_field = (kernel_size - 1) * dilation + 1
10 | padding = self.receptive_field // 2
11 | self.conv = nn.Conv1d(
12 | in_channels, out_channels, kernel_size,
13 | padding=padding,
14 | dilation=dilation,
15 | groups=groups
16 | )
17 | self.remove = 1 if self.receptive_field % 2 == 0 else 0
18 |
19 | def forward(self, x):
20 | out = self.conv(x)
21 | if self.remove > 0:
22 | out = out[:, :, : -self.remove]
23 | return out
24 |
25 |
26 | class ConvBlock(nn.Module):
27 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
28 | super().__init__()
29 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
30 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
31 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None
32 |
33 | def forward(self, x):
34 | residual = x if self.projector is None else self.projector(x)
35 | x = F.gelu(x)
36 | x = self.conv1(x)
37 | x = F.gelu(x)
38 | x = self.conv2(x)
39 | return x + residual
40 |
41 |
42 | class DilatedConvEncoder(nn.Module):
43 | def __init__(self, in_channels, channels, kernel_size, extract_layers=None):
44 | super().__init__()
45 |
46 | if extract_layers is not None:
47 | assert len(channels) - 1 in extract_layers
48 |
49 | self.extract_layers = extract_layers
50 | self.net = nn.Sequential(*[
51 | ConvBlock(
52 | channels[i-1] if i > 0 else in_channels,
53 | channels[i],
54 | kernel_size=kernel_size,
55 | dilation=2**i,
56 | final=(i == len(channels)-1)
57 | )
58 | for i in range(len(channels))
59 | ])
60 |
61 | def forward(self, x):
62 | if self.extract_layers is not None:
63 | outputs = []
64 | for idx, mod in enumerate(self.net):
65 | x = mod(x)
66 | if idx in self.extract_layers:
67 | outputs.append(x)
68 | return outputs
69 | return self.net(x)
70 |
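`SamePadConv` pads by half the receptive field and drops the last output step when that receptive field is even, so every convolution preserves the sequence length. Each `ConvBlock` stacks two such convolutions with dilation `2**i`, which makes the encoder's receptive field grow exponentially with depth. The pure-Python sketch below works through both calculations. It assumes stride 1 and the standard `nn.Conv1d` output-length formula, and the helper names are hypothetical.

```python
def same_pad_output_length(length: int, kernel_size: int, dilation: int) -> int:
    """Output length of SamePadConv for an input of `length` steps (stride 1)."""
    receptive_field = (kernel_size - 1) * dilation + 1
    padding = receptive_field // 2
    # nn.Conv1d, stride 1: L_out = L + 2 * padding - dilation * (kernel_size - 1)
    out = length + 2 * padding - dilation * (kernel_size - 1)
    if receptive_field % 2 == 0:
        out -= 1  # SamePadConv removes the extra trailing step
    return out


def encoder_receptive_field(kernel_size: int, num_blocks: int) -> int:
    """Receptive field of DilatedConvEncoder built from `num_blocks` ConvBlocks.

    Block i holds two convolutions with dilation 2**i, and each one widens
    the field by (kernel_size - 1) * 2**i steps.
    """
    return 1 + sum(2 * (kernel_size - 1) * 2 ** i for i in range(num_blocks))


if __name__ == "__main__":
    # CoSTEncoder: kernel_size=3, depth=10 hidden blocks plus one output block.
    print(encoder_receptive_field(3, 11))
```

With the encoder defaults this gives a field of 8189 steps, far wider than the 201-step crops used in the scripts.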
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.fft as fft
8 | from einops import reduce, rearrange, repeat
9 |
10 | import numpy as np
11 |
12 | from .dilated_conv import DilatedConvEncoder
13 |
14 |
15 | def generate_continuous_mask(B, T, n=5, l=0.1):
16 | res = torch.full((B, T), True, dtype=torch.bool)
17 | if isinstance(n, float):
18 | n = int(n * T)
19 | n = max(min(n, T // 2), 1)
20 |
21 | if isinstance(l, float):
22 | l = int(l * T)
23 | l = max(l, 1)
24 |
25 | for i in range(B):
26 | for _ in range(n):
27 | t = np.random.randint(T-l+1)
28 | res[i, t:t+l] = False
29 | return res
30 |
31 |
32 | def generate_binomial_mask(B, T, p=0.5):
33 | return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)
34 |
35 |
36 | class BandedFourierLayer(nn.Module):
37 | def __init__(self, in_channels, out_channels, band, num_bands, length=201):
38 | super().__init__()
39 |
40 | self.length = length
41 | self.total_freqs = (self.length // 2) + 1
42 |
43 | self.in_channels = in_channels
44 | self.out_channels = out_channels
45 |
46 | self.band = band # zero indexed
47 | self.num_bands = num_bands
48 |
49 | self.num_freqs = self.total_freqs // self.num_bands + (self.total_freqs % self.num_bands if self.band == self.num_bands - 1 else 0)
50 |
51 | self.start = self.band * (self.total_freqs // self.num_bands)
52 | self.end = self.start + self.num_freqs
53 |
54 |
55 | # case: from other frequencies
56 | self.weight = nn.Parameter(torch.empty((self.num_freqs, in_channels, out_channels), dtype=torch.cfloat))
57 | self.bias = nn.Parameter(torch.empty((self.num_freqs, out_channels), dtype=torch.cfloat))
58 | self.reset_parameters()
59 |
60 | def forward(self, input):
61 | # input - b t d
62 | b, t, _ = input.shape
63 | input_fft = fft.rfft(input, dim=1)
64 | output_fft = torch.zeros(b, t // 2 + 1, self.out_channels, device=input.device, dtype=torch.cfloat)
65 | output_fft[:, self.start:self.end] = self._forward(input_fft)
66 | return fft.irfft(output_fft, n=input.size(1), dim=1)
67 |
68 | def _forward(self, input):
69 | output = torch.einsum('bti,tio->bto', input[:, self.start:self.end], self.weight)
70 | return output + self.bias
71 |
72 | def reset_parameters(self) -> None:
73 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
74 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
75 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
76 | nn.init.uniform_(self.bias, -bound, bound)
77 |
78 |
79 | class CoSTEncoder(nn.Module):
80 | def __init__(self, input_dims, output_dims,
81 | kernels: List[int],
82 | length: int,
83 | hidden_dims=64, depth=10,
84 | mask_mode='binomial'):
85 | super().__init__()
86 |
87 | component_dims = output_dims // 2
88 |
89 | self.input_dims = input_dims
90 | self.output_dims = output_dims
91 | self.component_dims = component_dims
92 | self.hidden_dims = hidden_dims
93 | self.mask_mode = mask_mode
94 | self.input_fc = nn.Linear(input_dims, hidden_dims)
95 |
96 | self.feature_extractor = DilatedConvEncoder(
97 | hidden_dims,
98 | [hidden_dims] * depth + [output_dims],
99 | kernel_size=3
100 | )
101 |
102 | self.repr_dropout = nn.Dropout(p=0.1)
103 |
104 | self.kernels = kernels
105 |
106 | self.tfd = nn.ModuleList(
107 | [nn.Conv1d(output_dims, component_dims, k, padding=k-1) for k in kernels]
108 | )
109 |
110 | self.sfd = nn.ModuleList(
111 | [BandedFourierLayer(output_dims, component_dims, b, 1, length=length) for b in range(1)]
112 | )
113 |
114 | def forward(self, x, tcn_output=False, mask='all_true'): # x: B x T x input_dims
115 | nan_mask = ~x.isnan().any(axis=-1)
116 | x[~nan_mask] = 0
117 | x = self.input_fc(x) # B x T x Ch
118 |
119 | # generate & apply mask
120 | if mask is None:
121 | if self.training:
122 | mask = self.mask_mode
123 | else:
124 | mask = 'all_true'
125 |
126 | if mask == 'binomial':
127 | mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device)
128 | elif mask == 'continuous':
129 | mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device)
130 | elif mask == 'all_true':
131 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
132 | elif mask == 'all_false':
133 | mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
134 | elif mask == 'mask_last':
135 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
136 | mask[:, -1] = False
137 |
138 | mask &= nan_mask
139 | x[~mask] = 0
140 |
141 | # conv encoder
142 | x = x.transpose(1, 2) # B x Ch x T
143 | x = self.feature_extractor(x) # B x Co x T
144 |
145 | if tcn_output:
146 | return x.transpose(1, 2)
147 |
148 | trend = []
149 | for idx, mod in enumerate(self.tfd):
150 | out = mod(x) # b d t
151 | if self.kernels[idx] != 1:
152 | out = out[..., :-(self.kernels[idx] - 1)]
153 | trend.append(out.transpose(1, 2)) # b t d
154 | trend = reduce(
155 | rearrange(trend, 'list b t d -> list b t d'),
156 | 'list b t d -> b t d', 'mean'
157 | )
158 |
159 | x = x.transpose(1, 2) # B x T x Co
160 |
161 | season = []
162 | for mod in self.sfd:
163 | out = mod(x) # b t d
164 | season.append(out)
165 | season = season[0]
166 |
167 | return trend, self.repr_dropout(season)
168 |
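`BandedFourierLayer` learns the seasonal representation in the frequency domain. It takes the `rfft` along time, applies a separate complex linear map plus bias to every frequency in its band, zeros all other frequencies, and transforms back with `irfft`. The NumPy sketch below reproduces that computation for illustration only, and its function names are hypothetical. The module derives its band bounds from the constructor argument `length`. The sketch derives them from the input length `t` instead.

```python
import numpy as np


def band_bounds(length: int, band: int, num_bands: int):
    """[start, end) frequency indices for a zero-indexed band; the last band absorbs the remainder."""
    total = length // 2 + 1
    num = total // num_bands + (total % num_bands if band == num_bands - 1 else 0)
    start = band * (total // num_bands)
    return start, start + num


def banded_fourier(x, weight, bias, band, num_bands):
    """x: (batch, time, in_ch) real; weight: (num_freqs, in_ch, out_ch) complex; bias: (num_freqs, out_ch) complex."""
    b, t, _ = x.shape
    start, end = band_bounds(t, band, num_bands)
    x_fft = np.fft.rfft(x, axis=1)  # (batch, t // 2 + 1, in_ch)
    out_fft = np.zeros((b, t // 2 + 1, weight.shape[-1]), dtype=np.complex128)
    # A separate linear map per frequency inside the band; all other frequencies stay zero.
    out_fft[:, start:end] = np.einsum("bti,tio->bto", x_fft[:, start:end], weight) + bias
    return np.fft.irfft(out_fft, n=t, axis=1)
```

With a single band and identity weights, the layer reduces to an `rfft`/`irfft` round trip and returns its input unchanged. That makes a convenient sanity check.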
--------------------------------------------------------------------------------
/pics/CoST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoST/afc26aa0239470f522135f470861a1c375507e84/pics/CoST.png
--------------------------------------------------------------------------------
/pics/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoST/afc26aa0239470f522135f470861a1c375507e84/pics/results.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.6.1
2 | torch==1.9.0
3 | numpy==1.22.0
4 | pandas==1.0.1
5 | scikit_learn==0.24.1
6 | einops==0.3.0
7 |
--------------------------------------------------------------------------------
/scripts/ETT_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | # multivar
3 | python -u train.py ETTh1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | python -u train.py ETTh2 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
5 | python -u train.py ETTm1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
6 | # univar
7 | python -u train.py ETTh1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
8 | python -u train.py ETTh2 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
9 | python -u train.py ETTm1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
10 | done
11 |
--------------------------------------------------------------------------------
/scripts/Electricity_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | python -u train.py electricity forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
3 | python -u train.py electricity forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | done
--------------------------------------------------------------------------------
/scripts/M5_CoST.sh:
--------------------------------------------------------------------------------
1 | for level in $(seq 1 10); do
2 | for seed in $(seq 0 4); do
3 | # multivar
4 | python -u train.py M5-l${level} forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
5 | # univar
6 | python -u train.py M5-l${level} forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
7 | done
8 | done
--------------------------------------------------------------------------------
/scripts/Weather_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | # multivar
3 | python -u train.py WTH forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | # univar
5 | python -u train.py WTH forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
6 | done
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .forecasting import eval_forecasting
2 |
--------------------------------------------------------------------------------
/tasks/_eval_protocols.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.linear_model import Ridge
3 | from sklearn.model_selection import GridSearchCV, train_test_split
4 |
5 |
6 | def fit_ridge(train_features, train_y, valid_features, valid_y, MAX_SAMPLES=100000):
7 | # If the training set is too large, subsample MAX_SAMPLES examples
8 | if train_features.shape[0] > MAX_SAMPLES:
9 | split = train_test_split(
10 | train_features, train_y,
11 | train_size=MAX_SAMPLES, random_state=0
12 | )
13 | train_features = split[0]
14 | train_y = split[2]
15 | if valid_features.shape[0] > MAX_SAMPLES:
16 | split = train_test_split(
17 | valid_features, valid_y,
18 | train_size=MAX_SAMPLES, random_state=0
19 | )
20 | valid_features = split[0]
21 | valid_y = split[2]
22 | alphas = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
23 | valid_results = []
24 | for alpha in alphas:
25 | lr = Ridge(alpha=alpha).fit(train_features, train_y)
26 | valid_pred = lr.predict(valid_features)
27 | score = np.sqrt(((valid_pred - valid_y) ** 2).mean()) + np.abs(valid_pred - valid_y).mean()
28 | valid_results.append(score)
29 | best_alpha = alphas[np.argmin(valid_results)]
30 |
31 | lr = Ridge(alpha=best_alpha)
32 | lr.fit(train_features, train_y)
33 | return lr
34 |
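`fit_ridge` picks the ridge penalty in three steps:

1. Fit one model per candidate `alpha`.
2. Score each model on the validation split with RMSE + MAE.
3. Refit using the winning `alpha`.

The NumPy sketch below implements that selection loop with a closed-form ridge solve in place of sklearn's `Ridge`. Both fit an unpenalised intercept by centring the data, but their solvers differ, so treat this as an illustration rather than a drop-in replacement. The helper names are hypothetical.

```python
import numpy as np


def ridge_fit(X, y, alpha):
    """Closed-form ridge regression with an unpenalised intercept."""
    x_mean, y_mean = X.mean(axis=0), y.mean(axis=0)
    Xc, yc = X - x_mean, y - y_mean
    # Solve (Xc^T Xc + alpha * I) w = Xc^T yc
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ yc)
    b = y_mean - x_mean @ w
    return w, b


def select_alpha(X_tr, y_tr, X_va, y_va, alphas):
    """Return the alpha with the lowest validation RMSE + MAE."""
    scores = []
    for alpha in alphas:
        w, b = ridge_fit(X_tr, y_tr, alpha)
        err = X_va @ w + b - y_va
        scores.append(np.sqrt((err ** 2).mean()) + np.abs(err).mean())
    return alphas[int(np.argmin(scores))]
```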
--------------------------------------------------------------------------------
/tasks/forecasting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | from . import _eval_protocols as eval_protocols
4 |
5 |
6 | def generate_pred_samples(features, data, pred_len, drop=0):
7 | n = data.shape[1]
8 | features = features[:, :-pred_len]
9 | labels = np.stack([ data[:, i:1+n+i-pred_len] for i in range(pred_len)], axis=2)[:, 1:]
10 | features = features[:, drop:]
11 | labels = labels[:, drop:]
12 | return features.reshape(-1, features.shape[-1]), \
13 | labels.reshape(-1, labels.shape[2]*labels.shape[3])
14 |
15 |
16 | def cal_metrics(pred, target):
17 | return {
18 | 'MSE': ((pred - target) ** 2).mean(),
19 | 'MAE': np.abs(pred - target).mean()
20 | }
21 |
22 |
23 | def eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols, padding):
24 | t = time.time()
25 |
26 | all_repr = model.encode(
27 | data,
28 | mode='forecasting',
29 | casual=True,
30 | sliding_length=1,
31 | sliding_padding=padding,
32 | batch_size=256
33 | )
34 |
35 | train_repr = all_repr[:, train_slice]
36 | valid_repr = all_repr[:, valid_slice]
37 | test_repr = all_repr[:, test_slice]
38 |
39 | train_data = data[:, train_slice, n_covariate_cols:]
40 | valid_data = data[:, valid_slice, n_covariate_cols:]
41 | test_data = data[:, test_slice, n_covariate_cols:]
42 |
43 | encoder_infer_time = time.time() - t
44 |
45 | ours_result = {}
46 | lr_train_time = {}
47 | lr_infer_time = {}
48 | out_log = {}
49 | for pred_len in pred_lens:
50 | train_features, train_labels = generate_pred_samples(train_repr, train_data, pred_len, drop=padding)
51 | valid_features, valid_labels = generate_pred_samples(valid_repr, valid_data, pred_len)
52 | test_features, test_labels = generate_pred_samples(test_repr, test_data, pred_len)
53 |
54 | t = time.time()
55 | lr = eval_protocols.fit_ridge(train_features, train_labels, valid_features, valid_labels)
56 | lr_train_time[pred_len] = time.time() - t
57 |
58 | t = time.time()
59 | test_pred = lr.predict(test_features)
60 | lr_infer_time[pred_len] = time.time() - t
61 |
62 | ori_shape = test_data.shape[0], -1, pred_len, test_data.shape[2]
63 | test_pred = test_pred.reshape(ori_shape)
64 | test_labels = test_labels.reshape(ori_shape)
65 |
66 | if test_data.shape[0] > 1:
67 | test_pred_inv = scaler.inverse_transform(test_pred.swapaxes(0, 3)).swapaxes(0, 3)
68 | test_labels_inv = scaler.inverse_transform(test_labels.swapaxes(0, 3)).swapaxes(0, 3)
69 | else:
70 | test_pred_inv = scaler.inverse_transform(test_pred)
71 | test_labels_inv = scaler.inverse_transform(test_labels)
72 | out_log[pred_len] = {
73 | 'norm': test_pred,
74 | 'raw': test_pred_inv,
75 | 'norm_gt': test_labels,
76 | 'raw_gt': test_labels_inv
77 | }
78 | ours_result[pred_len] = {
79 | 'norm': cal_metrics(test_pred, test_labels),
80 | 'raw': cal_metrics(test_pred_inv, test_labels_inv)
81 | }
82 |
83 | eval_res = {
84 | 'ours': ours_result,
85 | 'encoder_infer_time': encoder_infer_time,
86 | 'lr_train_time': lr_train_time,
87 | 'lr_infer_time': lr_infer_time
88 | }
89 | return out_log, eval_res
90 |
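`generate_pred_samples` pairs the representation at time step `j` with the next `pred_len` observations, `data[j+1 : j+1+pred_len]`. It drops the last `pred_len` representations because they have no complete target window. The stacked-slice expression is compact but hard to read. The explicit loop below (a hypothetical helper for one univariate series) makes the alignment visible.

```python
import numpy as np


def next_step_targets(series: np.ndarray, pred_len: int) -> np.ndarray:
    """Row j holds series[j+1 : j+1+pred_len], with one row per usable time step."""
    n = len(series)
    return np.stack([series[j + 1 : j + 1 + pred_len] for j in range(n - pred_len)])


if __name__ == "__main__":
    print(next_step_targets(np.arange(6), pred_len=2))
    # rows: [1 2], [2 3], [3 4], [4 5]
```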
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import datetime
5 | import math
6 | import numpy as np
7 | import tasks
8 | import datautils
9 | from utils import init_dl_program, name_with_datetime, pkl_save, data_dropout
10 |
11 | # import methods
12 | from cost import CoST
13 |
14 |
15 | def save_checkpoint_callback(
16 | save_every=1,
17 | unit='epoch'
18 | ):
19 | assert unit in ('epoch', 'iter')
20 | def callback(model, loss):
21 | n = model.n_epochs if unit == 'epoch' else model.n_iters
22 | if n % save_every == 0:
23 | model.save(f'{run_dir}/model_{n}.pkl')
24 | return callback
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('dataset', help='The dataset name')
29 | parser.add_argument('run_name', help='The folder name used to save model, output and evaluation metrics. This can be set to any word')
30 | parser.add_argument('--archive', type=str, required=True, help='The archive the dataset belongs to: forecast_csv, forecast_csv_univar, forecast_npy, or forecast_npy_univar')
31 | parser.add_argument('--gpu', type=int, default=0, help='The GPU index used for training and inference (defaults to 0)')
32 | parser.add_argument('--batch-size', type=int, default=8, help='The batch size (defaults to 8)')
33 | parser.add_argument('--lr', type=float, default=0.001, help='The learning rate (defaults to 0.001)')
34 | parser.add_argument('--repr-dims', type=int, default=320, help='The representation dimension (defaults to 320)')
35 | parser.add_argument('--max-train-length', type=int, default=3000, help='Sequences longer than this are cropped into subsequences no longer than this value (defaults to 3000)')
36 | parser.add_argument('--iters', type=int, default=None, help='The number of iterations')
37 | parser.add_argument('--epochs', type=int, default=None, help='The number of epochs')
38 | parser.add_argument('--save-every', type=int, default=None, help='Save a checkpoint every SAVE_EVERY epochs if --epochs is set, otherwise every SAVE_EVERY iterations')
39 | parser.add_argument('--seed', type=int, default=None, help='The random seed')
40 | parser.add_argument('--max-threads', type=int, default=None, help='The maximum allowed number of threads used by this process')
41 | parser.add_argument('--eval', action="store_true", help='Whether to perform evaluation after training')
42 |
43 | parser.add_argument('--kernels', type=int, nargs='+', default=[1, 2, 4, 8, 16, 32, 64, 128], help='The kernel sizes used in the mixture of AR expert layers')
44 | parser.add_argument('--alpha', type=float, default=0.0005, help='Weighting hyperparameter for loss function')
45 |
46 | args = parser.parse_args()
47 |
48 | print("Dataset:", args.dataset)
49 | print("Arguments:", str(args))
50 |
51 | device = init_dl_program(args.gpu, seed=args.seed, max_threads=args.max_threads)
52 |
53 | if args.archive == 'forecast_csv':
54 | task_type = 'forecasting'
55 | data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_csv(args.dataset)
56 | train_data = data[:, train_slice]
57 | elif args.archive == 'forecast_csv_univar':
58 | task_type = 'forecasting'
59 | data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_csv(args.dataset, univar=True)
60 | train_data = data[:, train_slice]
61 | elif args.archive == 'forecast_npy':
62 | task_type = 'forecasting'
63 | data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_npy(args.dataset)
64 | train_data = data[:, train_slice]
65 | elif args.archive == 'forecast_npy_univar':
66 | task_type = 'forecasting'
67 | data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_npy(args.dataset, univar=True)
68 | train_data = data[:, train_slice]
69 | else:
70 | raise ValueError(f"Archive type {args.archive} is not supported.")
71 |
72 | config = dict(
73 | batch_size=args.batch_size,
74 | lr=args.lr,
75 | output_dims=args.repr_dims,
76 | )
77 |
78 | if args.save_every is not None:
79 | unit = 'epoch' if args.epochs is not None else 'iter'
80 | config[f'after_{unit}_callback'] = save_checkpoint_callback(args.save_every, unit)
81 |
82 | run_dir = f"training/{args.dataset}/{name_with_datetime(args.run_name)}"
83 |
84 | os.makedirs(run_dir, exist_ok=True)
85 |
86 | t = time.time()
87 |
88 | model = CoST(
89 | input_dims=train_data.shape[-1],
90 | kernels=args.kernels,
91 | alpha=args.alpha,
92 | max_train_length=args.max_train_length,
93 | device=device,
94 | **config
95 | )
96 |
97 | loss_log = model.fit(
98 | train_data,
99 | n_epochs=args.epochs,
100 | n_iters=args.iters,
101 | verbose=True
102 | )
103 | model.save(f'{run_dir}/model.pkl')
104 |
105 | t = time.time() - t
106 | print(f"\nTraining time: {datetime.timedelta(seconds=t)}\n")
107 |
108 | if args.eval:
109 | out, eval_res = tasks.eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols, args.max_train_length-1)
110 | print('Evaluation result:', eval_res)
111 | pkl_save(f'{run_dir}/eval_res.pkl', eval_res)
112 | pkl_save(f'{run_dir}/out.pkl', out)
113 |
114 | print("Finished.")
115 |
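At evaluation time `train.py` passes `max_train_length - 1` as the sliding padding. `eval_forecasting` reuses that value as `drop` for the training split, so the first `padding` steps are excluded from ridge training. Those steps are presumably dropped because their representations were computed with less than a full context window. The arithmetic sketch below (a hypothetical helper) counts how many training rows per instance remain. The 8640-step figure is the ETTh1/ETTh2 training split from `load_forecast_csv`.

```python
def ridge_train_rows(train_len: int, pred_len: int, padding: int) -> int:
    """Rows per instance produced by generate_pred_samples(..., drop=padding)."""
    # features[:, :-pred_len] keeps train_len - pred_len steps; drop removes `padding` more.
    return train_len - pred_len - padding


if __name__ == "__main__":
    padding = 201 - 1  # --max-train-length 201 as in scripts/ETT_CoST.sh
    for pred_len in (24, 48, 168, 336, 720):
        print(pred_len, ridge_train_rows(12 * 30 * 24, pred_len, padding))
```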
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import torch
5 | import random
6 | from datetime import datetime
7 | import torch.nn as nn
8 |
9 |
10 | def pkl_save(name, var):
11 | with open(name, 'wb') as f:
12 | pickle.dump(var, f)
13 |
14 | def pkl_load(name):
15 | with open(name, 'rb') as f:
16 | return pickle.load(f)
17 |
18 | def torch_pad_nan(arr, left=0, right=0, dim=0):
19 | if left > 0:
20 | padshape = list(arr.shape)
21 | padshape[dim] = left
22 | arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim)
23 | if right > 0:
24 | padshape = list(arr.shape)
25 | padshape[dim] = right
26 | arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim)
27 | return arr
28 |
29 | def pad_nan_to_target(array, target_length, axis=0, both_side=False):
30 | assert array.dtype in [np.float16, np.float32, np.float64]
31 | pad_size = target_length - array.shape[axis]
32 | if pad_size <= 0:
33 | return array
34 | npad = [(0, 0)] * array.ndim
35 | if both_side:
36 | npad[axis] = (pad_size // 2, pad_size - pad_size//2)
37 | else:
38 | npad[axis] = (0, pad_size)
39 | return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan)
40 |
41 | def split_with_nan(x, sections, axis=0):
42 | assert x.dtype in [np.float16, np.float32, np.float64]
43 | arrs = np.array_split(x, sections, axis=axis)
44 | target_length = arrs[0].shape[axis]
45 | for i in range(len(arrs)):
46 | arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis)
47 | return arrs
48 |
49 | def take_per_row(A, indx, num_elem):
50 | all_indx = indx[:,None] + np.arange(num_elem)
51 | return A[torch.arange(all_indx.shape[0])[:,None], all_indx]
52 |
53 | def centerize_vary_length_series(x):
54 | prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
55 | suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
56 | offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
57 | rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]]
58 | offset[offset < 0] += x.shape[1]
59 | column_indices = column_indices - offset[:, np.newaxis]
60 | return x[rows, column_indices]
61 |
62 | def data_dropout(arr, p):
63 | B, T = arr.shape[0], arr.shape[1]
64 | mask = np.full(B*T, False, dtype=bool)  # np.bool is a deprecated alias removed in NumPy 1.24
65 | ele_sel = np.random.choice(
66 | B*T,
67 | size=int(B*T*p),
68 | replace=False
69 | )
70 | mask[ele_sel] = True
71 | res = arr.copy()
72 | res[mask.reshape(B, T)] = np.nan
73 | return res
74 |
75 | def name_with_datetime(prefix='default'):
76 | now = datetime.now()
77 | return prefix + '_' + now.strftime("%Y%m%d_%H%M%S")
78 |
79 | def init_dl_program(
80 | device_name,
81 | seed=None,
82 | use_cudnn=True,
83 | deterministic=False,
84 | benchmark=False,
85 | use_tf32=False,
86 | max_threads=None
87 | ):
88 | import torch
89 | if max_threads is not None:
90 | torch.set_num_threads(max_threads) # intraop
91 | if torch.get_num_interop_threads() != max_threads:
92 | torch.set_num_interop_threads(max_threads) # interop
93 | try:
94 | import mkl
95 | except ImportError:
96 | pass
97 | else:
98 | mkl.set_num_threads(max_threads)
99 |
100 | if seed is not None:
101 | random.seed(seed)
102 | seed += 1
103 | np.random.seed(seed)
104 | seed += 1
105 | torch.manual_seed(seed)
106 |
107 | if isinstance(device_name, (str, int)):
108 | device_name = [device_name]
109 |
110 | devices = []
111 | for t in reversed(device_name):
112 | t_device = torch.device(t)
113 | devices.append(t_device)
114 | if t_device.type == 'cuda':
115 | assert torch.cuda.is_available()
116 | torch.cuda.set_device(t_device)
117 | if seed is not None:
118 | seed += 1
119 | torch.cuda.manual_seed(seed)
120 | devices.reverse()
121 | torch.backends.cudnn.enabled = use_cudnn
122 | torch.backends.cudnn.deterministic = deterministic
123 | torch.backends.cudnn.benchmark = benchmark
124 |
125 | if hasattr(torch.backends.cudnn, 'allow_tf32'):
126 | torch.backends.cudnn.allow_tf32 = use_tf32
127 | torch.backends.cuda.matmul.allow_tf32 = use_tf32
128 |
129 | return devices if len(devices) > 1 else devices[0]
130 |
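`take_per_row` gathers a window of `num_elem` consecutive steps for each row, starting at a different offset per row. It does this with advanced indexing: a `(rows, 1)` row index is broadcast against a `(rows, num_elem)` time index. The sketch below applies the same trick in plain NumPy (the function name is hypothetical). It also works on `(rows, time, features)` arrays.

```python
import numpy as np


def take_windows(A: np.ndarray, starts: np.ndarray, width: int) -> np.ndarray:
    """Return A[r, starts[r] : starts[r] + width] for every row r, stacked."""
    time_idx = starts[:, None] + np.arange(width)  # (rows, width)
    row_idx = np.arange(A.shape[0])[:, None]       # (rows, 1), broadcast over width
    return A[row_idx, time_idx]
```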
--------------------------------------------------------------------------------