├── models
│   ├── __init__.py
│   ├── dilated_conv.py
│   └── encoder.py
├── tasks
│   ├── __init__.py
│   ├── _eval_protocols.py
│   └── forecasting.py
├── pics
│   ├── CoST.png
│   └── results.png
├── datasets
│   ├── PLACE_DATASETS_HERE
│   ├── electricity.py
│   └── m5.py
├── requirements.txt
├── CODEOWNERS
├── SECURITY.md
├── scripts
│   ├── Electricity_CoST.sh
│   ├── Weather_CoST.sh
│   ├── M5_CoST.sh
│   └── ETT_CoST.sh
├── LICENSE.txt
├── datautils.py
├── utils.py
├── README.md
├── CODE_OF_CONDUCT.md
├── train.py
└── cost.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .forecasting import eval_forecasting
2 |
--------------------------------------------------------------------------------
/pics/CoST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoST/HEAD/pics/CoST.png
--------------------------------------------------------------------------------
/pics/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoST/HEAD/pics/results.png
--------------------------------------------------------------------------------
/datasets/PLACE_DATASETS_HERE:
--------------------------------------------------------------------------------
1 | Please follow the instructions in README.md to place the datasets into this folder.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.6.1
2 | torch==1.9.0
3 | numpy==1.22.0
4 | pandas==1.0.1
5 | scikit_learn==0.24.1
6 | einops==0.3.0
7 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing.
2 | #ECCN:Open Source
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/datasets/electricity.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
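# Build datasets/electricity.csv from the raw LD2011_2014.txt dump: aggregate
# the readings to hourly values, drop clients whose series is still all zero
# at row 8920, and keep the data from 2012 onwards.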
2 | data_ecl = pd.read_csv('LD2011_2014.txt', parse_dates=True, sep=';', decimal=',', index_col=0)
3 | data_ecl = data_ecl.resample('1h', closed='right').sum()
4 | data_ecl = data_ecl.loc[:, data_ecl.cumsum(axis=0).iloc[8920] != 0] # filter out instances with missing values
5 | data_ecl.index = data_ecl.index.rename('date')
6 | data_ecl = data_ecl['2012':]
7 | data_ecl.to_csv('electricity.csv')
--------------------------------------------------------------------------------
/scripts/Electricity_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | python -u train.py electricity forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
3 | python -u train.py electricity forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | done
--------------------------------------------------------------------------------
/scripts/Weather_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | # multivar
3 | python -u train.py WTH forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | # univar
5 | python -u train.py WTH forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
6 | done
--------------------------------------------------------------------------------
/scripts/M5_CoST.sh:
--------------------------------------------------------------------------------
1 | for level in $(seq 1 10); do
2 | for seed in $(seq 0 4); do
3 | # multivar
4 | python -u train.py M5-l${level} forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
5 | # univar
6 | python -u train.py M5-l${level} forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
7 | done
8 | done
--------------------------------------------------------------------------------
/scripts/ETT_CoST.sh:
--------------------------------------------------------------------------------
1 | for seed in $(seq 0 4); do
2 | # multivar
3 | python -u train.py ETTh1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
4 | python -u train.py ETTh2 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
5 | python -u train.py ETTm1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval
6 | # univar
7 | python -u train.py ETTh1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
8 | python -u train.py ETTh2 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
9 | python -u train.py ETTm1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval
10 | done
11 |
--------------------------------------------------------------------------------
/tasks/_eval_protocols.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.linear_model import Ridge
3 | from sklearn.model_selection import train_test_split
4 |
5 |
6 | def fit_ridge(train_features, train_y, valid_features, valid_y, MAX_SAMPLES=100000):
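    # Select the ridge regularisation strength on the validation split
    # (score = validation RMSE + MAE), then refit with the best alpha.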
7 | # If the training set is too large, subsample MAX_SAMPLES examples
8 | if train_features.shape[0] > MAX_SAMPLES:
9 | split = train_test_split(
10 | train_features, train_y,
11 | train_size=MAX_SAMPLES, random_state=0
12 | )
13 | train_features = split[0]
14 | train_y = split[2]
15 | if valid_features.shape[0] > MAX_SAMPLES:
16 | split = train_test_split(
17 | valid_features, valid_y,
18 | train_size=MAX_SAMPLES, random_state=0
19 | )
20 | valid_features = split[0]
21 | valid_y = split[2]
22 | alphas = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
23 | valid_results = []
24 | for alpha in alphas:
25 | lr = Ridge(alpha=alpha).fit(train_features, train_y)
26 | valid_pred = lr.predict(valid_features)
27 | score = np.sqrt(((valid_pred - valid_y) ** 2).mean()) + np.abs(valid_pred - valid_y).mean()
28 | valid_results.append(score)
29 | best_alpha = alphas[np.argmin(valid_results)]
30 |
31 | lr = Ridge(alpha=best_alpha)
32 | lr.fit(train_features, train_y)
33 | return lr
34 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/datasets/m5.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | calendar = pd.read_csv('calendar.csv', index_col='date', parse_dates=True)
5 | train_validation = pd.read_csv('sales_train_validation.csv')
6 | train_evaluation = pd.read_csv('sales_train_evaluation.csv')
7 | test_validation = pd.read_csv('sales_test_validation.csv')
8 | test_evaluation = pd.read_csv('sales_test_evaluation.csv')
9 |
10 | all_data = pd.merge(
11 | train_evaluation,
12 | test_evaluation,
13 | how="inner",
14 | on=None,
15 | left_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
16 | right_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
17 | sort=False,
18 | suffixes=("_x", "_y"),
19 | copy=True,
20 | indicator=False,
21 | validate=None,
22 | )
23 |
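# Aggregation levels: 'l1' is the grand total over all series; 'l2'-'l10'
# aggregate unit sales by state, store, category, department, item, and
# their pairwise combinations.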
24 | groups = {
25 | 'l1': ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
26 | 'l2': ['state_id'],
27 | 'l3': ['store_id'],
28 | 'l4': ['cat_id'],
29 | 'l5': ['dept_id'],
30 | 'l6': ['state_id', 'cat_id'],
31 | 'l7': ['state_id', 'dept_id'],
32 | 'l8': ['store_id', 'cat_id'],
33 | 'l9': ['store_id', 'dept_id'],
34 | 'l10': ['item_id'],
35 | }
36 |
37 | for k, v in groups.items():
38 | if k == 'l1':
39 | grouped_data = all_data.drop(columns=v).sum().to_frame(name='total')
40 | else:
41 | grouped_data = all_data.groupby(v).sum().transpose()
42 | grouped_data['date'] = calendar.index
43 | grouped_data = grouped_data.set_index('date')
44 |
45 | if isinstance(grouped_data.columns, pd.MultiIndex):
46 | grouped_data.columns = [c[0] + "_" + c[1] for c in grouped_data.columns]
47 |
48 | grouped_data.to_csv(f'M5-{k}.csv', index=True)
49 |
--------------------------------------------------------------------------------
/models/dilated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
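# Conv1d with "same"-style padding: the output keeps the input's temporal
# length; when the receptive field is even, the single extra padded element
# is trimmed in forward().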
6 | class SamePadConv(nn.Module):
7 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
8 | super().__init__()
9 | self.receptive_field = (kernel_size - 1) * dilation + 1
10 | padding = self.receptive_field // 2
11 | self.conv = nn.Conv1d(
12 | in_channels, out_channels, kernel_size,
13 | padding=padding,
14 | dilation=dilation,
15 | groups=groups
16 | )
17 | self.remove = 1 if self.receptive_field % 2 == 0 else 0
18 |
19 | def forward(self, x):
20 | out = self.conv(x)
21 | if self.remove > 0:
22 | out = out[:, :, : -self.remove]
23 | return out
24 |
25 |
26 | class ConvBlock(nn.Module):
27 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
28 | super().__init__()
29 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
30 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
31 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None
32 |
33 | def forward(self, x):
34 | residual = x if self.projector is None else self.projector(x)
35 | x = F.gelu(x)
36 | x = self.conv1(x)
37 | x = F.gelu(x)
38 | x = self.conv2(x)
39 | return x + residual
40 |
41 |
42 | class DilatedConvEncoder(nn.Module):
43 | def __init__(self, in_channels, channels, kernel_size, extract_layers=None):
44 | super().__init__()
45 |
46 | if extract_layers is not None:
47 | assert len(channels) - 1 in extract_layers
48 |
49 | self.extract_layers = extract_layers
50 | self.net = nn.Sequential(*[
51 | ConvBlock(
52 | channels[i-1] if i > 0 else in_channels,
53 | channels[i],
54 | kernel_size=kernel_size,
55 | dilation=2**i,
56 | final=(i == len(channels)-1)
57 | )
58 | for i in range(len(channels))
59 | ])
60 |
61 | def forward(self, x):
62 | if self.extract_layers is not None:
63 | outputs = []
64 | for idx, mod in enumerate(self.net):
65 | x = mod(x)
66 | if idx in self.extract_layers:
67 | outputs.append(x)
68 | return outputs
69 | return self.net(x)
70 |
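# Usage sketch: inputs are (batch, channels, time); dilation doubles per block
# while SamePadConv preserves the temporal length, e.g.
#   encoder = DilatedConvEncoder(in_channels=8, channels=[64, 64, 320], kernel_size=3)
#   out = encoder(torch.randn(2, 8, 201))  # -> shape (2, 320, 201)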
--------------------------------------------------------------------------------
/tasks/forecasting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | from . import _eval_protocols as eval_protocols
4 |
5 |
6 | def generate_pred_samples(features, data, pred_len, drop=0):
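    # Pair the representation at each timestep t with the observations at
    # t+1 ... t+pred_len; `drop` discards the first timesteps (e.g. the
    # sliding-window padding) from both features and labels.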
7 | n = data.shape[1]
8 | features = features[:, :-pred_len]
9 | labels = np.stack([ data[:, i:1+n+i-pred_len] for i in range(pred_len)], axis=2)[:, 1:]
10 | features = features[:, drop:]
11 | labels = labels[:, drop:]
12 | return features.reshape(-1, features.shape[-1]), \
13 | labels.reshape(-1, labels.shape[2]*labels.shape[3])
14 |
15 |
16 | def cal_metrics(pred, target):
17 | return {
18 | 'MSE': ((pred - target) ** 2).mean(),
19 | 'MAE': np.abs(pred - target).mean()
20 | }
21 |
22 |
23 | def eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols, padding):
24 | t = time.time()
25 |
26 | all_repr = model.encode(
27 | data,
28 | mode='forecasting',
29 | casual=True,
30 | sliding_length=1,
31 | sliding_padding=padding,
32 | batch_size=256
33 | )
34 |
35 | train_repr = all_repr[:, train_slice]
36 | valid_repr = all_repr[:, valid_slice]
37 | test_repr = all_repr[:, test_slice]
38 |
39 | train_data = data[:, train_slice, n_covariate_cols:]
40 | valid_data = data[:, valid_slice, n_covariate_cols:]
41 | test_data = data[:, test_slice, n_covariate_cols:]
42 |
43 | encoder_infer_time = time.time() - t
44 |
45 | ours_result = {}
46 | lr_train_time = {}
47 | lr_infer_time = {}
48 | out_log = {}
49 | for pred_len in pred_lens:
50 | train_features, train_labels = generate_pred_samples(train_repr, train_data, pred_len, drop=padding)
51 | valid_features, valid_labels = generate_pred_samples(valid_repr, valid_data, pred_len)
52 | test_features, test_labels = generate_pred_samples(test_repr, test_data, pred_len)
53 |
54 | t = time.time()
55 | lr = eval_protocols.fit_ridge(train_features, train_labels, valid_features, valid_labels)
56 | lr_train_time[pred_len] = time.time() - t
57 |
58 | t = time.time()
59 | test_pred = lr.predict(test_features)
60 | lr_infer_time[pred_len] = time.time() - t
61 |
62 | ori_shape = test_data.shape[0], -1, pred_len, test_data.shape[2]
63 | test_pred = test_pred.reshape(ori_shape)
64 | test_labels = test_labels.reshape(ori_shape)
65 |
66 | if test_data.shape[0] > 1:
67 | test_pred_inv = scaler.inverse_transform(test_pred.swapaxes(0, 3)).swapaxes(0, 3)
68 | test_labels_inv = scaler.inverse_transform(test_labels.swapaxes(0, 3)).swapaxes(0, 3)
69 | else:
70 | test_pred_inv = scaler.inverse_transform(test_pred)
71 | test_labels_inv = scaler.inverse_transform(test_labels)
72 | out_log[pred_len] = {
73 | 'norm': test_pred,
74 | 'raw': test_pred_inv,
75 | 'norm_gt': test_labels,
76 | 'raw_gt': test_labels_inv
77 | }
78 | ours_result[pred_len] = {
79 | 'norm': cal_metrics(test_pred, test_labels),
80 | 'raw': cal_metrics(test_pred_inv, test_labels_inv)
81 | }
82 |
83 | eval_res = {
84 | 'ours': ours_result,
85 | 'encoder_infer_time': encoder_infer_time,
86 | 'lr_train_time': lr_train_time,
87 | 'lr_infer_time': lr_infer_time
88 | }
89 | return out_log, eval_res
90 |
--------------------------------------------------------------------------------
/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
4 |
5 |
6 | def load_forecast_npy(name, univar=False):
7 | data = np.load(f'datasets/{name}.npy')
8 | if univar:
9 |         data = data[:, -1:]  # univariate setting: keep only the last variable
10 |
11 | train_slice = slice(None, int(0.6 * len(data)))
12 | valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data)))
13 | test_slice = slice(int(0.8 * len(data)), None)
14 |
15 | scaler = StandardScaler().fit(data[train_slice])
16 | data = scaler.transform(data)
17 | data = np.expand_dims(data, 0)
18 |
19 | pred_lens = [24, 48, 96, 288, 672]
20 | return data, train_slice, valid_slice, test_slice, scaler, pred_lens, 0
21 |
22 | def _get_time_features(dt):
23 | return np.stack([
24 | dt.minute.to_numpy(),
25 | dt.hour.to_numpy(),
26 | dt.dayofweek.to_numpy(),
27 | dt.day.to_numpy(),
28 | dt.dayofyear.to_numpy(),
29 | dt.month.to_numpy(),
30 | dt.weekofyear.to_numpy(),
31 |     ], axis=1).astype(float)
32 |
33 | def load_forecast_csv(name, univar=False):
34 | data = pd.read_csv(f'datasets/{name}.csv', index_col='date', parse_dates=True)
35 | dt_embed = _get_time_features(data.index)
36 | n_covariate_cols = dt_embed.shape[-1]
37 |
38 | if univar:
39 | if name in ('ETTh1', 'ETTh2', 'ETTm1', 'ETTm2'):
40 | data = data[['OT']]
41 | elif name == 'electricity':
42 | data = data[['MT_001']]
43 | elif name == 'WTH':
44 | data = data[['WetBulbCelsius']]
45 | else:
46 | data = data.iloc[:, -1:]
47 |
48 | data = data.to_numpy()
49 | if name == 'ETTh1' or name == 'ETTh2':
50 | train_slice = slice(None, 12 * 30 * 24)
51 | valid_slice = slice(12 * 30 * 24, 16 * 30 * 24)
52 | test_slice = slice(16 * 30 * 24, 20 * 30 * 24)
53 | elif name == 'ETTm1' or name == 'ETTm2':
54 | train_slice = slice(None, 12 * 30 * 24 * 4)
55 | valid_slice = slice(12 * 30 * 24 * 4, 16 * 30 * 24 * 4)
56 | test_slice = slice(16 * 30 * 24 * 4, 20 * 30 * 24 * 4)
57 | elif name.startswith('M5'):
58 | train_slice = slice(None, int(0.8 * (1913 + 28)))
59 | valid_slice = slice(int(0.8 * (1913 + 28)), 1913 + 28)
60 | test_slice = slice(1913 + 28 - 1, 1913 + 2 * 28)
61 | else:
62 | train_slice = slice(None, int(0.6 * len(data)))
63 | valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data)))
64 | test_slice = slice(int(0.8 * len(data)), None)
65 |
66 | scaler = StandardScaler().fit(data[train_slice])
67 | data = scaler.transform(data)
68 |     if name in ('electricity',) or name.startswith('M5'):
69 | data = np.expand_dims(data.T, -1) # Each variable is an instance rather than a feature
70 | else:
71 | data = np.expand_dims(data, 0)
72 |
73 | if n_covariate_cols > 0:
74 | dt_scaler = StandardScaler().fit(dt_embed[train_slice])
75 | dt_embed = np.expand_dims(dt_scaler.transform(dt_embed), 0)
76 | data = np.concatenate([np.repeat(dt_embed, data.shape[0], axis=0), data], axis=-1)
77 |
78 | if name in ('ETTh1', 'ETTh2', 'electricity', 'WTH'):
79 | pred_lens = [24, 48, 168, 336, 720]
80 | elif name.startswith('M5'):
81 | pred_lens = [28]
82 | else:
83 | pred_lens = [24, 48, 96, 288, 672]
84 |
85 | return data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols
86 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import torch
5 | import random
6 | from datetime import datetime
7 | import torch.nn as nn
8 |
9 |
10 | def pkl_save(name, var):
11 | with open(name, 'wb') as f:
12 | pickle.dump(var, f)
13 |
14 | def pkl_load(name):
15 | with open(name, 'rb') as f:
16 | return pickle.load(f)
17 |
18 | def torch_pad_nan(arr, left=0, right=0, dim=0):
19 | if left > 0:
20 | padshape = list(arr.shape)
21 | padshape[dim] = left
22 | arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim)
23 | if right > 0:
24 | padshape = list(arr.shape)
25 | padshape[dim] = right
26 | arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim)
27 | return arr
28 |
29 | def pad_nan_to_target(array, target_length, axis=0, both_side=False):
30 | assert array.dtype in [np.float16, np.float32, np.float64]
31 | pad_size = target_length - array.shape[axis]
32 | if pad_size <= 0:
33 | return array
34 | npad = [(0, 0)] * array.ndim
35 | if both_side:
36 | npad[axis] = (pad_size // 2, pad_size - pad_size//2)
37 | else:
38 | npad[axis] = (0, pad_size)
39 | return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan)
40 |
41 | def split_with_nan(x, sections, axis=0):
42 | assert x.dtype in [np.float16, np.float32, np.float64]
43 | arrs = np.array_split(x, sections, axis=axis)
44 | target_length = arrs[0].shape[axis]
45 | for i in range(len(arrs)):
46 | arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis)
47 | return arrs
48 |
49 | def take_per_row(A, indx, num_elem):
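    # For each row i, gather the window A[i, indx[i] : indx[i] + num_elem].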
50 | all_indx = indx[:,None] + np.arange(num_elem)
51 | return A[torch.arange(all_indx.shape[0])[:,None], all_indx]
52 |
53 | def centerize_vary_length_series(x):
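    # Shift each series along the time axis so that its leading and trailing
    # NaN padding is split evenly between both ends.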
54 | prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
55 | suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
56 | offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
57 | rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]]
58 | offset[offset < 0] += x.shape[1]
59 | column_indices = column_indices - offset[:, np.newaxis]
60 | return x[rows, column_indices]
61 |
62 | def data_dropout(arr, p):
63 | B, T = arr.shape[0], arr.shape[1]
64 |     mask = np.full(B*T, False, dtype=bool)
65 | ele_sel = np.random.choice(
66 | B*T,
67 | size=int(B*T*p),
68 | replace=False
69 | )
70 | mask[ele_sel] = True
71 | res = arr.copy()
72 | res[mask.reshape(B, T)] = np.nan
73 | return res
74 |
75 | def name_with_datetime(prefix='default'):
76 | now = datetime.now()
77 | return prefix + '_' + now.strftime("%Y%m%d_%H%M%S")
78 |
79 | def init_dl_program(
80 | device_name,
81 | seed=None,
82 | use_cudnn=True,
83 | deterministic=False,
84 | benchmark=False,
85 | use_tf32=False,
86 | max_threads=None
87 | ):
88 | import torch
89 | if max_threads is not None:
90 | torch.set_num_threads(max_threads) # intraop
91 | if torch.get_num_interop_threads() != max_threads:
92 | torch.set_num_interop_threads(max_threads) # interop
93 | try:
94 | import mkl
95 |         except ImportError:
96 | pass
97 | else:
98 | mkl.set_num_threads(max_threads)
99 |
100 | if seed is not None:
101 | random.seed(seed)
102 | seed += 1
103 | np.random.seed(seed)
104 | seed += 1
105 | torch.manual_seed(seed)
106 |
107 | if isinstance(device_name, (str, int)):
108 | device_name = [device_name]
109 |
110 | devices = []
111 | for t in reversed(device_name):
112 | t_device = torch.device(t)
113 | devices.append(t_device)
114 | if t_device.type == 'cuda':
115 | assert torch.cuda.is_available()
116 | torch.cuda.set_device(t_device)
117 | if seed is not None:
118 | seed += 1
119 | torch.cuda.manual_seed(seed)
120 | devices.reverse()
121 | torch.backends.cudnn.enabled = use_cudnn
122 | torch.backends.cudnn.deterministic = deterministic
123 | torch.backends.cudnn.benchmark = benchmark
124 |
125 | if hasattr(torch.backends.cudnn, 'allow_tf32'):
126 | torch.backends.cudnn.allow_tf32 = use_tf32
127 | torch.backends.cuda.matmul.allow_tf32 = use_tf32
128 |
129 | return devices if len(devices) > 1 else devices[0]
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
2 |
3 |
4 |
5 |
6 | Figure 1. Overall CoST Architecture.
7 |
8 |
9 | Official PyTorch code repository for the [CoST paper](https://openreview.net/forum?id=PilZY3omXV2).
10 |
11 | * CoST is a contrastive learning method for learning disentangled seasonal-trend representations for time series forecasting.
12 | * CoST consistently outperforms state-of-the-art methods by a considerable margin, achieving a 21.3% improvement in MSE on multivariate benchmarks.
13 |
14 | ## Requirements
15 | 1. Install Python 3.8.
16 | 2. Install the required dependencies: ```pip install -r requirements.txt```
17 |
18 | ## Data
19 |
20 | The datasets can be obtained and placed into the `datasets/` folder as follows (a short preparation sketch follows the list):
21 |
22 | * [3 ETT datasets](https://github.com/zhouhaoyi/ETDataset) should be placed at `datasets/ETTh1.csv`, `datasets/ETTh2.csv` and `datasets/ETTm1.csv`.
23 | * [Electricity dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014) should be placed at `datasets/LD2011_2014.txt`; then run `electricity.py`.
24 | * [Weather dataset](https://drive.google.com/drive/folders/1ohGYWWohJlOlb2gsGTeEq3Wii2egnEPR) (link from the [Informer repository](https://github.com/zhouhaoyi/Informer2020)) should be placed at `datasets/WTH.csv`.
25 | * [M5 dataset](https://drive.google.com/drive/folders/1D6EWdVSaOtrP1LEFh1REjI3vej6iUS_4): place `calendar.csv`, `sales_train_validation.csv`, `sales_train_evaluation.csv`, `sales_test_validation.csv` and `sales_test_evaluation.csv` in `datasets/`, then run `m5.py`.
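
A minimal preparation sketch, assuming the raw files listed above have already been downloaded into `datasets/`:

```bash
cd datasets
python electricity.py   # writes electricity.csv from LD2011_2014.txt
python m5.py            # writes M5-l1.csv ... M5-l10.csv from the M5 csv files
cd ..
```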
26 |
27 | ## Usage
28 | To train and evaluate CoST on a dataset, run the corresponding script from the `scripts` folder, e.g. ```./scripts/ETT_CoST.sh``` (make the scripts executable first via ```chmod u+x scripts/*```).
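
Each script loops over random seeds and calls `train.py`; a single run can also be launched directly, for example (flags taken from `scripts/ETT_CoST.sh`):

```bash
python -u train.py ETTh1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 \
    --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 \
    --max-threads 8 --seed 0 --eval
```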
29 |
30 | After training and evaluation, the trained encoder, output and evaluation metrics can be found in `training//__