├── models ├── __init__.py ├── dilated_conv.py └── encoder.py ├── tasks ├── __init__.py ├── _eval_protocols.py └── forecasting.py ├── pics ├── CoST.png └── results.png ├── datasets ├── PLACE_DATASETS_HERE ├── electricity.py └── m5.py ├── requirements.txt ├── CODEOWNERS ├── SECURITY.md ├── scripts ├── Electricity_CoST.sh ├── Weather_CoST.sh ├── M5_CoST.sh └── ETT_CoST.sh ├── LICENSE.txt ├── datautils.py ├── utils.py ├── README.md ├── CODE_OF_CONDUCT.md ├── train.py └── cost.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .forecasting import eval_forecasting 2 | -------------------------------------------------------------------------------- /pics/CoST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/CoST/HEAD/pics/CoST.png -------------------------------------------------------------------------------- /pics/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/CoST/HEAD/pics/results.png -------------------------------------------------------------------------------- /datasets/PLACE_DATASETS_HERE: -------------------------------------------------------------------------------- 1 | Please follow the instructions in README.md to place the datasets into this folder. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.6.1 2 | torch==1.9.0 3 | numpy==1.22.0 4 | pandas==1.0.1 5 | scikit_learn==0.24.1 6 | einops==0.3.0 7 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. 2 | #ECCN:Open Source -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 
-------------------------------------------------------------------------------- /datasets/electricity.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | data_ecl = pd.read_csv('LD2011_2014.txt', parse_dates=True, sep=';', decimal=',', index_col=0) 3 | data_ecl = data_ecl.resample('1h', closed='right').sum() 4 | data_ecl = data_ecl.loc[:, data_ecl.cumsum(axis=0).iloc[8920] != 0] # filter out instances with missing values 5 | data_ecl.index = data_ecl.index.rename('date') 6 | data_ecl = data_ecl['2012':] 7 | data_ecl.to_csv('electricity.csv') -------------------------------------------------------------------------------- /scripts/Electricity_CoST.sh: -------------------------------------------------------------------------------- 1 | for seed in $(seq 0 4); do 2 | python -u train.py electricity forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 3 | python -u train.py electricity forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 4 | done -------------------------------------------------------------------------------- /scripts/Weather_CoST.sh: -------------------------------------------------------------------------------- 1 | for seed in $(seq 0 4); do 2 | # multivar 3 | python -u train.py WTH forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 4 | # univar 5 | python -u train.py WTH forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 6 | done -------------------------------------------------------------------------------- /scripts/M5_CoST.sh: -------------------------------------------------------------------------------- 1 | for level in $(seq 1 10); do 2 | for seed in $(seq 0 4); do 3 | # multivar 4 | python -u train.py M5-l${level} forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 5 | # univar 6 | python -u train.py M5-l${level} forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 7 | done 8 | done -------------------------------------------------------------------------------- /scripts/ETT_CoST.sh: -------------------------------------------------------------------------------- 1 | for seed in $(seq 0 4); do 2 | # multivar 3 | python -u train.py ETTh1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 4 | python -u train.py ETTh2 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 5 | python -u train.py ETTm1 forecast_multivar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv --repr-dims 320 --max-threads 8 --seed ${seed} --eval 6 | # 
univar 7 | python -u train.py ETTh1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 8 | python -u train.py ETTh2 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 9 | python -u train.py ETTm1 forecast_univar --alpha 0.0005 --kernels 1 2 4 8 16 32 64 128 --max-train-length 201 --batch-size 128 --archive forecast_csv_univar --repr-dims 320 --max-threads 8 --seed ${seed} --eval 10 | done 11 | -------------------------------------------------------------------------------- /tasks/_eval_protocols.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import Ridge 3 | from sklearn.model_selection import GridSearchCV, train_test_split 4 | 5 | 6 | def fit_ridge(train_features, train_y, valid_features, valid_y, MAX_SAMPLES=100000): 7 | # If the training set is too large, subsample MAX_SAMPLES examples 8 | if train_features.shape[0] > MAX_SAMPLES: 9 | split = train_test_split( 10 | train_features, train_y, 11 | train_size=MAX_SAMPLES, random_state=0 12 | ) 13 | train_features = split[0] 14 | train_y = split[2] 15 | if valid_features.shape[0] > MAX_SAMPLES: 16 | split = train_test_split( 17 | valid_features, valid_y, 18 | train_size=MAX_SAMPLES, random_state=0 19 | ) 20 | valid_features = split[0] 21 | valid_y = split[2] 22 | alphas = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000] 23 | valid_results = [] 24 | for alpha in alphas: 25 | lr = Ridge(alpha=alpha).fit(train_features, train_y) 26 | valid_pred = lr.predict(valid_features) 27 | score = np.sqrt(((valid_pred - valid_y) ** 2).mean()) + np.abs(valid_pred - valid_y).mean() 28 | valid_results.append(score) 29 | best_alpha = alphas[np.argmin(valid_results)] 30 | 31 | lr = Ridge(alpha=best_alpha) 32 | lr.fit(train_features, train_y) 33 | return lr 34 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /datasets/m5.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | calendar = pd.read_csv('calendar.csv', index_col='date', parse_dates=True) 5 | train_validation = pd.read_csv('sales_train_validation.csv') 6 | train_evaluation = pd.read_csv('sales_train_evaluation.csv') 7 | test_validation = pd.read_csv('sales_test_validation.csv') 8 | test_evaluation = pd.read_csv('sales_test_evaluation.csv') 9 | 10 | all_data = pd.merge( 11 | train_evaluation, 12 | test_evaluation, 13 | how="inner", 14 | on=None, 15 | left_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'], 16 | right_on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'], 17 | sort=False, 18 | suffixes=("_x", "_y"), 19 | copy=True, 20 | indicator=False, 21 | validate=None, 22 | ) 23 | 24 | groups = { 25 | 'l1': ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'], 26 | 'l2': ['state_id'], 27 | 'l3': ['store_id'], 28 | 'l4': ['cat_id'], 29 | 'l5': ['dept_id'], 30 | 'l6': ['state_id', 'cat_id'], 31 | 'l7': ['state_id', 'dept_id'], 32 | 'l8': ['store_id', 'cat_id'], 33 | 'l9': ['store_id', 'dept_id'], 34 | 'l10': ['item_id'], 35 | } 36 | 37 | for k, v in groups.items(): 38 | if k == 'l1': 39 | grouped_data = all_data.drop(columns=v).sum().to_frame(name='total') 40 | else: 41 | grouped_data = all_data.groupby(v).sum().transpose() 42 | grouped_data['date'] = calendar.index 43 | grouped_data = grouped_data.set_index('date') 44 | 45 | if isinstance(grouped_data.columns, pd.MultiIndex): 46 | grouped_data.columns = [c[0] + "_" + c[1] for c in grouped_data.columns] 47 | 48 | grouped_data.to_csv(f'M5-{k}.csv', index=True) 49 | -------------------------------------------------------------------------------- /models/dilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class SamePadConv(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): 8 | super().__init__() 9 | self.receptive_field = (kernel_size - 1) * dilation + 1 10 | padding = self.receptive_field // 2 11 | self.conv = nn.Conv1d( 12 | in_channels, out_channels, kernel_size, 13 | padding=padding, 14 | dilation=dilation, 15 | groups=groups 16 | ) 17 | self.remove = 1 if self.receptive_field % 2 == 0 else 0 18 | 19 | def forward(self, x): 20 | out = self.conv(x) 21 | if self.remove > 0: 22 | out = out[:, :, : -self.remove] 23 | return out 24 | 25 | 26 | class ConvBlock(nn.Module): 27 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): 28 | super().__init__() 29 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) 30 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) 31 | self.projector = nn.Conv1d(in_channels, 
out_channels, 1) if in_channels != out_channels or final else None 32 | 33 | def forward(self, x): 34 | residual = x if self.projector is None else self.projector(x) 35 | x = F.gelu(x) 36 | x = self.conv1(x) 37 | x = F.gelu(x) 38 | x = self.conv2(x) 39 | return x + residual 40 | 41 | 42 | class DilatedConvEncoder(nn.Module): 43 | def __init__(self, in_channels, channels, kernel_size, extract_layers=None): 44 | super().__init__() 45 | 46 | if extract_layers is not None: 47 | assert len(channels) - 1 in extract_layers 48 | 49 | self.extract_layers = extract_layers 50 | self.net = nn.Sequential(*[ 51 | ConvBlock( 52 | channels[i-1] if i > 0 else in_channels, 53 | channels[i], 54 | kernel_size=kernel_size, 55 | dilation=2**i, 56 | final=(i == len(channels)-1) 57 | ) 58 | for i in range(len(channels)) 59 | ]) 60 | 61 | def forward(self, x): 62 | if self.extract_layers is not None: 63 | outputs = [] 64 | for idx, mod in enumerate(self.net): 65 | x = mod(x) 66 | if idx in self.extract_layers: 67 | outputs.append(x) 68 | return outputs 69 | return self.net(x) 70 | -------------------------------------------------------------------------------- /tasks/forecasting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from . import _eval_protocols as eval_protocols 4 | 5 | 6 | def generate_pred_samples(features, data, pred_len, drop=0): 7 | n = data.shape[1] 8 | features = features[:, :-pred_len] 9 | labels = np.stack([ data[:, i:1+n+i-pred_len] for i in range(pred_len)], axis=2)[:, 1:] 10 | features = features[:, drop:] 11 | labels = labels[:, drop:] 12 | return features.reshape(-1, features.shape[-1]), \ 13 | labels.reshape(-1, labels.shape[2]*labels.shape[3]) 14 | 15 | 16 | def cal_metrics(pred, target): 17 | return { 18 | 'MSE': ((pred - target) ** 2).mean(), 19 | 'MAE': np.abs(pred - target).mean() 20 | } 21 | 22 | 23 | def eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols, padding): 24 | t = time.time() 25 | 26 | all_repr = model.encode( 27 | data, 28 | mode='forecasting', 29 | casual=True, 30 | sliding_length=1, 31 | sliding_padding=padding, 32 | batch_size=256 33 | ) 34 | 35 | train_repr = all_repr[:, train_slice] 36 | valid_repr = all_repr[:, valid_slice] 37 | test_repr = all_repr[:, test_slice] 38 | 39 | train_data = data[:, train_slice, n_covariate_cols:] 40 | valid_data = data[:, valid_slice, n_covariate_cols:] 41 | test_data = data[:, test_slice, n_covariate_cols:] 42 | 43 | encoder_infer_time = time.time() - t 44 | 45 | ours_result = {} 46 | lr_train_time = {} 47 | lr_infer_time = {} 48 | out_log = {} 49 | for pred_len in pred_lens: 50 | train_features, train_labels = generate_pred_samples(train_repr, train_data, pred_len, drop=padding) 51 | valid_features, valid_labels = generate_pred_samples(valid_repr, valid_data, pred_len) 52 | test_features, test_labels = generate_pred_samples(test_repr, test_data, pred_len) 53 | 54 | t = time.time() 55 | lr = eval_protocols.fit_ridge(train_features, train_labels, valid_features, valid_labels) 56 | lr_train_time[pred_len] = time.time() - t 57 | 58 | t = time.time() 59 | test_pred = lr.predict(test_features) 60 | lr_infer_time[pred_len] = time.time() - t 61 | 62 | ori_shape = test_data.shape[0], -1, pred_len, test_data.shape[2] 63 | test_pred = test_pred.reshape(ori_shape) 64 | test_labels = test_labels.reshape(ori_shape) 65 | 66 | if test_data.shape[0] > 1: 67 | test_pred_inv = 
scaler.inverse_transform(test_pred.swapaxes(0, 3)).swapaxes(0, 3) 68 | test_labels_inv = scaler.inverse_transform(test_labels.swapaxes(0, 3)).swapaxes(0, 3) 69 | else: 70 | test_pred_inv = scaler.inverse_transform(test_pred) 71 | test_labels_inv = scaler.inverse_transform(test_labels) 72 | out_log[pred_len] = { 73 | 'norm': test_pred, 74 | 'raw': test_pred_inv, 75 | 'norm_gt': test_labels, 76 | 'raw_gt': test_labels_inv 77 | } 78 | ours_result[pred_len] = { 79 | 'norm': cal_metrics(test_pred, test_labels), 80 | 'raw': cal_metrics(test_pred_inv, test_labels_inv) 81 | } 82 | 83 | eval_res = { 84 | 'ours': ours_result, 85 | 'encoder_infer_time': encoder_infer_time, 86 | 'lr_train_time': lr_train_time, 87 | 'lr_infer_time': lr_infer_time 88 | } 89 | return out_log, eval_res 90 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 4 | 5 | 6 | def load_forecast_npy(name, univar=False): 7 | data = np.load(f'datasets/{name}.npy') 8 | if univar: 9 | data = data[: -1:] 10 | 11 | train_slice = slice(None, int(0.6 * len(data))) 12 | valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data))) 13 | test_slice = slice(int(0.8 * len(data)), None) 14 | 15 | scaler = StandardScaler().fit(data[train_slice]) 16 | data = scaler.transform(data) 17 | data = np.expand_dims(data, 0) 18 | 19 | pred_lens = [24, 48, 96, 288, 672] 20 | return data, train_slice, valid_slice, test_slice, scaler, pred_lens, 0 21 | 22 | def _get_time_features(dt): 23 | return np.stack([ 24 | dt.minute.to_numpy(), 25 | dt.hour.to_numpy(), 26 | dt.dayofweek.to_numpy(), 27 | dt.day.to_numpy(), 28 | dt.dayofyear.to_numpy(), 29 | dt.month.to_numpy(), 30 | dt.weekofyear.to_numpy(), 31 | ], axis=1).astype(np.float) 32 | 33 | def load_forecast_csv(name, univar=False): 34 | data = pd.read_csv(f'datasets/{name}.csv', index_col='date', parse_dates=True) 35 | dt_embed = _get_time_features(data.index) 36 | n_covariate_cols = dt_embed.shape[-1] 37 | 38 | if univar: 39 | if name in ('ETTh1', 'ETTh2', 'ETTm1', 'ETTm2'): 40 | data = data[['OT']] 41 | elif name == 'electricity': 42 | data = data[['MT_001']] 43 | elif name == 'WTH': 44 | data = data[['WetBulbCelsius']] 45 | else: 46 | data = data.iloc[:, -1:] 47 | 48 | data = data.to_numpy() 49 | if name == 'ETTh1' or name == 'ETTh2': 50 | train_slice = slice(None, 12 * 30 * 24) 51 | valid_slice = slice(12 * 30 * 24, 16 * 30 * 24) 52 | test_slice = slice(16 * 30 * 24, 20 * 30 * 24) 53 | elif name == 'ETTm1' or name == 'ETTm2': 54 | train_slice = slice(None, 12 * 30 * 24 * 4) 55 | valid_slice = slice(12 * 30 * 24 * 4, 16 * 30 * 24 * 4) 56 | test_slice = slice(16 * 30 * 24 * 4, 20 * 30 * 24 * 4) 57 | elif name.startswith('M5'): 58 | train_slice = slice(None, int(0.8 * (1913 + 28))) 59 | valid_slice = slice(int(0.8 * (1913 + 28)), 1913 + 28) 60 | test_slice = slice(1913 + 28 - 1, 1913 + 2 * 28) 61 | else: 62 | train_slice = slice(None, int(0.6 * len(data))) 63 | valid_slice = slice(int(0.6 * len(data)), int(0.8 * len(data))) 64 | test_slice = slice(int(0.8 * len(data)), None) 65 | 66 | scaler = StandardScaler().fit(data[train_slice]) 67 | data = scaler.transform(data) 68 | if name in ('electricity') or name.startswith('M5'): 69 | data = np.expand_dims(data.T, -1) # Each variable is an instance rather than a feature 70 | else: 71 | data = np.expand_dims(data, 0) 72 | 73 | 
if n_covariate_cols > 0: 74 | dt_scaler = StandardScaler().fit(dt_embed[train_slice]) 75 | dt_embed = np.expand_dims(dt_scaler.transform(dt_embed), 0) 76 | data = np.concatenate([np.repeat(dt_embed, data.shape[0], axis=0), data], axis=-1) 77 | 78 | if name in ('ETTh1', 'ETTh2', 'electricity', 'WTH'): 79 | pred_lens = [24, 48, 168, 336, 720] 80 | elif name.startswith('M5'): 81 | pred_lens = [28] 82 | else: 83 | pred_lens = [24, 48, 96, 288, 672] 84 | 85 | return data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols 86 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import torch 5 | import random 6 | from datetime import datetime 7 | import torch.nn as nn 8 | 9 | 10 | def pkl_save(name, var): 11 | with open(name, 'wb') as f: 12 | pickle.dump(var, f) 13 | 14 | def pkl_load(name): 15 | with open(name, 'rb') as f: 16 | return pickle.load(f) 17 | 18 | def torch_pad_nan(arr, left=0, right=0, dim=0): 19 | if left > 0: 20 | padshape = list(arr.shape) 21 | padshape[dim] = left 22 | arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim) 23 | if right > 0: 24 | padshape = list(arr.shape) 25 | padshape[dim] = right 26 | arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim) 27 | return arr 28 | 29 | def pad_nan_to_target(array, target_length, axis=0, both_side=False): 30 | assert array.dtype in [np.float16, np.float32, np.float64] 31 | pad_size = target_length - array.shape[axis] 32 | if pad_size <= 0: 33 | return array 34 | npad = [(0, 0)] * array.ndim 35 | if both_side: 36 | npad[axis] = (pad_size // 2, pad_size - pad_size//2) 37 | else: 38 | npad[axis] = (0, pad_size) 39 | return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan) 40 | 41 | def split_with_nan(x, sections, axis=0): 42 | assert x.dtype in [np.float16, np.float32, np.float64] 43 | arrs = np.array_split(x, sections, axis=axis) 44 | target_length = arrs[0].shape[axis] 45 | for i in range(len(arrs)): 46 | arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis) 47 | return arrs 48 | 49 | def take_per_row(A, indx, num_elem): 50 | all_indx = indx[:,None] + np.arange(num_elem) 51 | return A[torch.arange(all_indx.shape[0])[:,None], all_indx] 52 | 53 | def centerize_vary_length_series(x): 54 | prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1) 55 | suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1) 56 | offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros 57 | rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]] 58 | offset[offset < 0] += x.shape[1] 59 | column_indices = column_indices - offset[:, np.newaxis] 60 | return x[rows, column_indices] 61 | 62 | def data_dropout(arr, p): 63 | B, T = arr.shape[0], arr.shape[1] 64 | mask = np.full(B*T, False, dtype=np.bool) 65 | ele_sel = np.random.choice( 66 | B*T, 67 | size=int(B*T*p), 68 | replace=False 69 | ) 70 | mask[ele_sel] = True 71 | res = arr.copy() 72 | res[mask.reshape(B, T)] = np.nan 73 | return res 74 | 75 | def name_with_datetime(prefix='default'): 76 | now = datetime.now() 77 | return prefix + '_' + now.strftime("%Y%m%d_%H%M%S") 78 | 79 | def init_dl_program( 80 | device_name, 81 | seed=None, 82 | use_cudnn=True, 83 | deterministic=False, 84 | benchmark=False, 85 | use_tf32=False, 86 | max_threads=None 87 | ): 88 | import torch 89 | if max_threads is not None: 90 | torch.set_num_threads(max_threads) # 
intraop 91 | if torch.get_num_interop_threads() != max_threads: 92 | torch.set_num_interop_threads(max_threads) # interop 93 | try: 94 | import mkl 95 | except: 96 | pass 97 | else: 98 | mkl.set_num_threads(max_threads) 99 | 100 | if seed is not None: 101 | random.seed(seed) 102 | seed += 1 103 | np.random.seed(seed) 104 | seed += 1 105 | torch.manual_seed(seed) 106 | 107 | if isinstance(device_name, (str, int)): 108 | device_name = [device_name] 109 | 110 | devices = [] 111 | for t in reversed(device_name): 112 | t_device = torch.device(t) 113 | devices.append(t_device) 114 | if t_device.type == 'cuda': 115 | assert torch.cuda.is_available() 116 | torch.cuda.set_device(t_device) 117 | if seed is not None: 118 | seed += 1 119 | torch.cuda.manual_seed(seed) 120 | devices.reverse() 121 | torch.backends.cudnn.enabled = use_cudnn 122 | torch.backends.cudnn.deterministic = deterministic 123 | torch.backends.cudnn.benchmark = benchmark 124 | 125 | if hasattr(torch.backends.cudnn, 'allow_tf32'): 126 | torch.backends.cudnn.allow_tf32 = use_tf32 127 | torch.backends.cuda.matmul.allow_tf32 = use_tf32 128 | 129 | return devices if len(devices) > 1 else devices[0] 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022) 2 | 3 |
![CoST architecture](pics/CoST.png) 4 | 5 | 6 | Figure 1. Overall CoST Architecture. 7 |
8 | 9 | Official PyTorch code repository for the [CoST paper](https://openreview.net/forum?id=PilZY3omXV2). 10 | 11 | * CoST is a contrastive learning method for learning disentangled seasonal-trend representations for time series forecasting. 12 | * CoST consistently outperforms state-of-the-art methods by a considerable margin, achieving a 21.3% improvement in MSE on multivariate benchmarks. 13 | 14 | ## Requirements 15 | 1. Install Python 3.8. 16 | 2. Install the required dependencies: ```pip install -r requirements.txt``` 17 | 18 | ## Data 19 | 20 | The datasets can be obtained and placed into the `datasets/` folder as follows: 21 | 22 | * [3 ETT datasets](https://github.com/zhouhaoyi/ETDataset) should be placed at `datasets/ETTh1.csv`, `datasets/ETTh2.csv` and `datasets/ETTm1.csv`. 23 | * [Electricity dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014) should be placed at `datasets/LD2011_2014.txt`; then run `electricity.py`. 24 | * [Weather dataset](https://drive.google.com/drive/folders/1ohGYWWohJlOlb2gsGTeEq3Wii2egnEPR) (link from the [Informer repository](https://github.com/zhouhaoyi/Informer2020)) should be placed at `datasets/WTH.csv`. 25 | * [M5 dataset](https://drive.google.com/drive/folders/1D6EWdVSaOtrP1LEFh1REjI3vej6iUS_4): place `calendar.csv`, `sales_train_validation.csv`, `sales_train_evaluation.csv`, `sales_test_validation.csv` and `sales_test_evaluation.csv` in `datasets/`; then run `m5.py`. 26 | 27 | ## Usage 28 | To train and evaluate CoST on a dataset, run the corresponding script from the `scripts/` folder, e.g. ```./scripts/ETT_CoST.sh``` (grant execute permissions first via ```chmod u+x scripts/*```). 29 | 30 | After training and evaluation, the trained encoder, output and evaluation metrics can be found in `training//__