├── ESRNN
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── losses.py
│   │   ├── data.py
│   │   ├── ESRNN.py
│   │   └── DRNN.py
│   ├── __init__.py
│   ├── tests
│   │   └── test_esrnn.py
│   ├── utils_visualization.py
│   ├── m4_run.py
│   ├── utils_configs.py
│   ├── m4_data.py
│   ├── utils_evaluation.py
│   ├── ESRNNensemble.py
│   └── ESRNN.py
├── requirements.txt
├── .github
│   ├── images
│   │   ├── metrics.png
│   │   ├── x_test.png
│   │   ├── x_train.png
│   │   ├── y_test.png
│   │   ├── y_train.png
│   │   ├── test-data-example.png
│   │   └── train-data-example.png
│   └── workflows
│       └── pythonpackage.yml
├── .gitignore
├── setup.py
├── LICENSE
└── README.md
/ESRNN/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ESRNN/__init__.py: -------------------------------------------------------------------------------- 1 | from ESRNN.ESRNN import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.1 2 | pandas>=0.25.2 3 | torch>=1.3.1 4 | -------------------------------------------------------------------------------- /.github/images/metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/metrics.png -------------------------------------------------------------------------------- /.github/images/x_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/x_test.png -------------------------------------------------------------------------------- /.github/images/x_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/x_train.png -------------------------------------------------------------------------------- /.github/images/y_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/y_test.png -------------------------------------------------------------------------------- /.github/images/y_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/y_train.png -------------------------------------------------------------------------------- /.github/images/test-data-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/test-data-example.png -------------------------------------------------------------------------------- /.github/images/train-data-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdgutier/esrnn_torch/HEAD/.github/images/train-data-example.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.DS_Store 3 | *.vscode 4 | *.ipynb_checkpoints 5 | data/ 6 | !results/ 7 | results/* 8 | literature/* 9 | plots/ 10 | statistics/ 11 | dynet/ 12 | theory/ 13 | *.so 14 | *.dll 15 | *.exe 16 | *.c 17 | build/ 18 | 
ESRNN/hyperpar_tunning_m4.py 19 | logs/* 20 | configs/* 21 | setup.sh 22 | server_results.sh 23 | test.py 24 | 25 | # Setuptools distribution folder. 26 | /dist/ 27 | 28 | # Python egg metadata, regenerated from source files by setuptools. 29 | /*.egg-info 30 | /*.egg 31 | ESRNN/r 32 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master, pip ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6, 3.7, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | - name: Test with pytest 31 | run: | 32 | pip install pytest 33 | pytest -s 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="ESRNN", 8 | version="0.1.3", 9 | author="Kin Gutierrez, Cristian Challu, Federico Garza", 10 | author_email="kdgutier@cs.cmu.edu, cchallu@andrew.cmu.edu, fede.garza.ramirez@gmail.com", 11 | description="Pytorch implementation of the ESRNN", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/kdgutier/esrnn_torch", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | install_requires =[ 23 | "numpy>=1.16.1", 24 | "pandas>=0.25.2", 25 | "torch>=1.3.1" 26 | ], 27 | entry_points=''' 28 | [console_scripts] 29 | m4_run=ESRNN.m4_run:cli 30 | ''' 31 | ) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kin Gutierrez and Cristian Challu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /ESRNN/utils/config.py: -------------------------------------------------------------------------------- 1 | class ModelConfig(object): 2 | def __init__(self, max_epochs, batch_size, batch_size_test, freq_of_test, 3 | learning_rate, lr_scheduler_step_size, lr_decay, 4 | per_series_lr_multip, gradient_eps, gradient_clipping_threshold, 5 | rnn_weight_decay, 6 | noise_std, 7 | level_variability_penalty, 8 | testing_percentile, training_percentile, ensemble, 9 | cell_type, 10 | state_hsize, dilations, add_nl_layer, seasonality, input_size, output_size, 11 | frequency, max_periods, random_seed, device, root_dir): 12 | 13 | # Train Parameters 14 | self.max_epochs = max_epochs 15 | self.batch_size = batch_size 16 | self.batch_size_test = batch_size_test 17 | self.freq_of_test = freq_of_test 18 | self.learning_rate = learning_rate 19 | self.lr_scheduler_step_size = lr_scheduler_step_size 20 | self.lr_decay = lr_decay 21 | self.per_series_lr_multip = per_series_lr_multip 22 | self.gradient_eps = gradient_eps 23 | self.gradient_clipping_threshold = gradient_clipping_threshold 24 | self.rnn_weight_decay = rnn_weight_decay 25 | self.noise_std = noise_std 26 | self.level_variability_penalty = level_variability_penalty 27 | self.testing_percentile = testing_percentile 28 | self.training_percentile = training_percentile 29 | self.ensemble = ensemble 30 | self.device = device 31 | 32 | # Model Parameters 33 | self.cell_type = cell_type 34 | self.state_hsize = state_hsize 35 | self.dilations = dilations 36 | self.add_nl_layer = add_nl_layer 37 | self.random_seed = random_seed 38 | 39 | # Data Parameters 40 | self.seasonality = seasonality 41 | if len(seasonality)>0: 42 | self.naive_seasonality = seasonality[0] 43 | else: 44 | self.naive_seasonality = 1 45 | self.input_size = input_size 46 | self.input_size_i = self.input_size 47 | self.output_size = output_size 48 | self.output_size_i = self.output_size 49 | self.frequency = frequency 50 | self.min_series_length = self.input_size_i + self.output_size_i 51 | self.max_series_length = (max_periods * self.input_size) + self.min_series_length 52 | self.max_periods = max_periods 53 | self.root_dir = root_dir -------------------------------------------------------------------------------- /ESRNN/tests/test_esrnn.py: -------------------------------------------------------------------------------- 1 | #Testing ESRNN 2 | import runpy 3 | import os 4 | 5 | print('\n') 6 | print(10*'='+'TEST ESRNN'+10*'=') 7 | print('\n') 8 | 9 | def test_esrnn_hourly(): 10 | if not os.path.exists('./data'): 11 | os.mkdir('./data') 12 | 13 | print('\n') 14 | print(10*'='+'HOURLY'+10*'=') 15 | print('\n') 16 | 17 | exec_str = 'python -m ESRNN.m4_run --dataset Hourly ' 18 | exec_str += '--results_directory ./data --gpu_id 0 ' 19 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 20 | results = os.system(exec_str) 21 | 22 | if results==0: 23 | print('Test completed') 24 | else: 25 | raise Exception('Something went 
wrong') 26 | 27 | def test_esrnn_weekly(): 28 | if not os.path.exists('./data'): 29 | os.mkdir('./data') 30 | 31 | print('\n') 32 | print(10*'='+'WEEKLY'+10*'=') 33 | print('\n') 34 | 35 | exec_str = 'python -m ESRNN.m4_run --dataset Weekly ' 36 | exec_str += '--results_directory ./data --gpu_id 0 ' 37 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 38 | results = os.system(exec_str) 39 | 40 | if results==0: 41 | print('Test completed') 42 | else: 43 | raise Exception('Something went wrong') 44 | 45 | 46 | def test_esrnn_daily(): 47 | if not os.path.exists('./data'): 48 | os.mkdir('./data') 49 | 50 | print('\n') 51 | print(10*'='+'DAILY'+10*'=') 52 | print('\n') 53 | 54 | exec_str = 'python -m ESRNN.m4_run --dataset Daily ' 55 | exec_str += '--results_directory ./data --gpu_id 0 ' 56 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 57 | results = os.system(exec_str) 58 | 59 | if results==0: 60 | print('Test completed') 61 | else: 62 | raise Exception('Something went wrong') 63 | 64 | 65 | def test_esrnn_monthly(): 66 | if not os.path.exists('./data'): 67 | os.mkdir('./data') 68 | 69 | 70 | print('\n') 71 | print(10*'='+'MONTHLY'+10*'=') 72 | print('\n') 73 | 74 | exec_str = 'python -m ESRNN.m4_run --dataset Monthly ' 75 | exec_str += '--results_directory ./data --gpu_id 0 ' 76 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 77 | results = os.system(exec_str) 78 | 79 | if results==0: 80 | print('Test completed') 81 | else: 82 | raise Exception('Something went wrong') 83 | 84 | 85 | def test_esrnn_quarterly(): 86 | if not os.path.exists('./data'): 87 | os.mkdir('./data') 88 | 89 | print('\n') 90 | print(10*'='+'QUARTERLY'+10*'=') 91 | print('\n') 92 | 93 | exec_str = 'python -m ESRNN.m4_run --dataset Quarterly ' 94 | exec_str += '--results_directory ./data --gpu_id 0 ' 95 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 96 | results = os.system(exec_str) 97 | 98 | if results==0: 99 | print('Test completed') 100 | else: 101 | raise Exception('Something went wrong') 102 | 103 | 104 | def test_esrnn_yearly(): 105 | if not os.path.exists('./data'): 106 | os.mkdir('./data') 107 | 108 | print('\n') 109 | print(10*'='+'YEARLY'+10*'=') 110 | print('\n') 111 | 112 | exec_str = 'python -m ESRNN.m4_run --dataset Yearly ' 113 | exec_str += '--results_directory ./data --gpu_id 0 ' 114 | exec_str += '--use_cpu 1 --num_obs 100 --test 1' 115 | results = os.system(exec_str) 116 | 117 | if results==0: 118 | print('Test completed') 119 | else: 120 | raise Exception('Something went wrong') 121 | -------------------------------------------------------------------------------- /ESRNN/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class PinballLoss(nn.Module): 5 | """ Pinball Loss 6 | Computes the pinball loss between y and y_hat. 7 | 8 | Parameters 9 | ---------- 10 | y: tensor 11 | actual values in torch tensor. 12 | y_hat: tensor (same shape as y) 13 | predicted values in torch tensor. 14 | tau: float, between 0 and 1 15 | the slope of the pinball loss, in the context of 16 | quantile regression, the value of tau determines the 17 | conditional quantile level. 
18 | 19 | Returns 20 | ---------- 21 | pinball_loss: 22 | average pinball loss for the predicted quantile 23 | """ 24 | def __init__(self, tau=0.5): 25 | super(PinballLoss, self).__init__() 26 | self.tau = tau 27 | 28 | def forward(self, y, y_hat): 29 | delta_y = torch.sub(y, y_hat) 30 | pinball = torch.max(torch.mul(self.tau, delta_y), torch.mul((self.tau-1), delta_y)) 31 | pinball = pinball.mean() 32 | return pinball 33 | 34 | class LevelVariabilityLoss(nn.Module): 35 | """ Level Variability Loss 36 | Computes the variability penalty for the level. 37 | 38 | Parameters 39 | ---------- 40 | levels: tensor with shape (batch, n_time) 41 | levels obtained from exponential smoothing component of ESRNN 42 | level_variability_penalty: float 43 | this parameter controls the strength of the penalization 44 | to the wiggliness of the level vector; it induces smoothness 45 | in the output 46 | 47 | Returns 48 | ---------- 49 | level_var_loss: 50 | wiggliness loss for the level vector 51 | """ 52 | def __init__(self, level_variability_penalty): 53 | super(LevelVariabilityLoss, self).__init__() 54 | self.level_variability_penalty = level_variability_penalty 55 | 56 | def forward(self, levels): 57 | assert levels.shape[1] > 2 58 | level_prev = torch.log(levels[:, :-1]) 59 | level_next = torch.log(levels[:, 1:]) 60 | log_diff_of_levels = torch.sub(level_prev, level_next) 61 | 62 | log_diff_prev = log_diff_of_levels[:, :-1] 63 | log_diff_next = log_diff_of_levels[:, 1:] 64 | diff = torch.sub(log_diff_prev, log_diff_next) 65 | level_var_loss = diff**2 66 | level_var_loss = level_var_loss.mean() * self.level_variability_penalty 67 | return level_var_loss 68 | 69 | class StateLoss(nn.Module): 70 | pass 71 | 72 | class SmylLoss(nn.Module): 73 | """Computes the Smyl Loss that combines the level variability loss 74 | with the Pinball loss. 75 | windows_y: tensor of actual values, 76 | shape (n_windows, batch_size, window_size). 77 | windows_y_hat: tensor of predicted values, 78 | shape (n_windows, batch_size, window_size). 79 | levels: levels obtained from exponential smoothing component of ESRNN. 80 | tensor with shape (batch, n_time). 81 | return: smyl_loss. 82 | """ 83 | def __init__(self, tau, level_variability_penalty=0.0): 84 | super(SmylLoss, self).__init__() 85 | self.pinball_loss = PinballLoss(tau) 86 | self.level_variability_loss = LevelVariabilityLoss(level_variability_penalty) 87 | 88 | def forward(self, windows_y, windows_y_hat, levels): 89 | smyl_loss = self.pinball_loss(windows_y, windows_y_hat) 90 | if self.level_variability_loss.level_variability_penalty>0: 91 | log_diff_of_levels = self.level_variability_loss(levels) 92 | smyl_loss += log_diff_of_levels 93 | return smyl_loss 94 | 95 | 96 | class DisaggregatedPinballLoss(nn.Module): 97 | """ Pinball Loss 98 | Computes the pinball loss between y and y_hat. 99 | 100 | Parameters 101 | ---------- 102 | y: tensor 103 | actual values in torch tensor. 104 | y_hat: tensor (same shape as y) 105 | predicted values in torch tensor. 106 | tau: float, between 0 and 1 107 | the slope of the pinball loss, in the context of 108 | quantile regression, the value of tau determines the 109 | conditional quantile level. 
110 | 111 | Returns 112 | ---------- 113 | pinball_loss: 114 | average pinball loss for the predicted quantile 115 | """ 116 | def __init__(self, tau=0.5): 117 | super(DisaggregatedPinballLoss, self).__init__() 118 | self.tau = tau 119 | 120 | def forward(self, y, y_hat): 121 | delta_y = torch.sub(y, y_hat) 122 | pinball = torch.max(torch.mul(self.tau, delta_y), torch.mul((self.tau-1), delta_y)) 123 | pinball = pinball.mean(axis=0).mean(axis=1) 124 | return pinball -------------------------------------------------------------------------------- /ESRNN/utils_visualization.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | plt.style.use('ggplot') 6 | 7 | import seaborn as sns 8 | from itertools import product 9 | import random 10 | 11 | 12 | def plot_prediction(y, y_hat): 13 | """ 14 | y: pandas df 15 | panel with columns unique_id, ds, y 16 | y_hat: pandas df 17 | panel with columns unique_id, ds, y_hat 18 | """ 19 | pd.plotting.register_matplotlib_converters() 20 | 21 | plt.plot(y.ds, y.y, label = 'y') 22 | plt.plot(y_hat.ds, y_hat.y_hat, label='y_hat') 23 | plt.legend(loc='upper left') 24 | plt.show() 25 | 26 | def plot_grid_prediction(y, y_hat, plot_random=True, unique_ids=None, save_file_name = None): 27 | """ 28 | y: pandas df 29 | panel with columns unique_id, ds, y 30 | y_hat: pandas df 31 | panel with columns unique_id, ds, y_hat 32 | plot_random: bool 33 | if unique_ids will be sampled 34 | unique_ids: list 35 | unique_ids to plot 36 | save_file_name: str 37 | file name to save plot 38 | """ 39 | pd.plotting.register_matplotlib_converters() 40 | 41 | fig, axes = plt.subplots(2, 4, figsize = (24,8)) 42 | 43 | if not unique_ids: 44 | unique_ids = y['unique_id'].unique() 45 | 46 | assert len(unique_ids) >= 8, "Must provide at least 8 ts" 47 | 48 | if plot_random: 49 | unique_ids = random.sample(set(unique_ids), k=8) 50 | 51 | for i, (idx, idy) in enumerate(product(range(2), range(4))): 52 | y_uid = y[y.unique_id == unique_ids[i]] 53 | y_uid_hat = y_hat[y_hat.unique_id == unique_ids[i]] 54 | 55 | axes[idx, idy].plot(y_uid.ds, y_uid.y, label = 'y') 56 | axes[idx, idy].plot(y_uid_hat.ds, y_uid_hat.y_hat, label='y_hat') 57 | axes[idx, idy].set_title(unique_ids[i]) 58 | axes[idx, idy].legend(loc='upper left') 59 | 60 | plt.show() 61 | 62 | if save_file_name: 63 | fig.savefig(save_file_name, bbox_inches='tight', pad_inches=0) 64 | 65 | 66 | def plot_distributions(distributions_dict, fig_title=None, xlabel=None): 67 | n_distributions = len(distributions_dict.keys()) 68 | fig, ax = plt.subplots(1, figsize=(7, 5.5)) 69 | plt.subplots_adjust(wspace=0.35) 70 | 71 | n_colors = len(distributions_dict.keys()) 72 | colors = sns.color_palette("hls", n_colors) 73 | 74 | for idx, dist_name in enumerate(distributions_dict.keys()): 75 | train_dist_plot = sns.kdeplot(distributions_dict[dist_name], 76 | bw='silverman', 77 | label=dist_name, 78 | color=colors[idx]) 79 | if xlabel is not None: 80 | ax.set_xlabel(xlabel, fontsize=14) 81 | ax.set_ylabel('Density', fontsize=14) 82 | ax.set_title(fig_title, fontsize=15.5) 83 | ax.grid(True) 84 | ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 85 | 86 | fig.tight_layout() 87 | if fig_title is not None: 88 | fig_title = fig_title.replace(' ', '_') 89 | plot_file = "./results/plots/{}_distributions.png".format(fig_title) 90 | plt.savefig(plot_file, bbox_inches = "tight", dpi=300) 91 | plt.show() 92 | 93 | def 
plot_cat_distributions(df, cat, var): 94 | unique_cats = df[cat].unique() 95 | cat_dict = {} 96 | for c in unique_cats: 97 | cat_dict[c] = df[df[cat]==c][var].values 98 | 99 | plot_distributions(cat_dict, xlabel=var) 100 | 101 | def plot_single_cat_distributions(distributions_dict, ax, fig_title=None, xlabel=None): 102 | n_distributions = len(distributions_dict.keys()) 103 | 104 | n_colors = len(distributions_dict.keys()) 105 | colors = sns.color_palette("hls", n_colors) 106 | 107 | for idx, dist_name in enumerate(distributions_dict.keys()): 108 | train_dist_plot = sns.distplot(distributions_dict[dist_name], 109 | #bw='silverman', 110 | #kde=False, 111 | rug=True, 112 | label=dist_name, 113 | color=colors[idx], 114 | ax=ax) 115 | if xlabel is not None: 116 | ax.set_xlabel(xlabel, fontsize=14) 117 | ax.set_ylabel('Density', fontsize=14) 118 | ax.set_title(fig_title, fontsize=15.5) 119 | ax.grid(True) 120 | ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 121 | 122 | def plot_grid_cat_distributions(df, cats, var): 123 | cols = int(np.ceil(len(cats)/2)) 124 | fig, axs = plt.subplots(2, cols, figsize=(4*cols, 5.5)) 125 | plt.subplots_adjust(wspace=0.95) 126 | plt.subplots_adjust(hspace=0.5) 127 | 128 | for idx, cat in enumerate(cats): 129 | unique_cats = df[cat].unique() 130 | cat_dict = {} 131 | for c in unique_cats: 132 | values = df[df[cat]==c][var].values 133 | values = values[~np.isnan(values)] 134 | if len(values)>0: 135 | cat_dict[c] = values 136 | 137 | row = int(np.round((idx/len(cats))+0.001, 0)) 138 | col = idx % cols 139 | plot_single_cat_distributions(cat_dict, axs[row, col], 140 | fig_title=cat, xlabel=var) 141 | 142 | min_owa = math.floor(df.min_owa.min() * 1000) / 1000 143 | suptitle = var + ': ' + str(min_owa) 144 | fig.suptitle(suptitle, fontsize=18) 145 | plt.show() 146 | -------------------------------------------------------------------------------- /ESRNN/utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Batch(): 6 | def __init__(self, mc, y, last_ds, categories, idxs): 7 | # Parse Model config 8 | exogenous_size = mc.exogenous_size 9 | device = mc.device 10 | 11 | # y: time series values 12 | n = len(y) 13 | y = np.float32(y) 14 | self.idxs = torch.LongTensor(idxs).to(device) 15 | self.y = y 16 | if (self.y.shape[1] > mc.max_series_length): 17 | y = y[:, -mc.max_series_length:] 18 | self.y = torch.tensor(y).float() 19 | 20 | # last_ds: last time for prediction purposes 21 | self.last_ds = last_ds 22 | 23 | # categories: exogenous categoric data 24 | if exogenous_size >0: 25 | self.categories = np.zeros((len(idxs), exogenous_size)) 26 | cols_idx = np.array([mc.category_to_idx[category] for category in categories]) 27 | rows_idx = np.array(range(len(cols_idx))) 28 | self.categories[rows_idx, cols_idx] = 1 29 | self.categories = torch.from_numpy(self.categories).float() 30 | 31 | self.y = self.y.to(device) 32 | self.categories = self.categories.to(device) 33 | 34 | 35 | class Iterator(object): 36 | """ Time Series Iterator. 37 | 38 | Parameters 39 | ---------- 40 | mc: ModelConfig object 41 | ModelConfig object with inherited hyperparameters: 42 | batch_size, and exogenous_size, from the ESRNN 43 | initialization. 44 | X: array, shape (n_unique_id, 3) 45 | Panel array with unique_id, last date stamp and 46 | exogenous variable. 47 | y: array, shape (n_unique_id, n_time) 48 | Panel array in wide format with unique_id, last 49 | date stamp and time series values. 
50 | Returns 51 | ---------- 52 | self : object 53 | Iterator method get_batch() returns a batch of time 54 | series objects defined by the Batch class. 55 | """ 56 | def __init__(self, mc, X, y, weights=None): 57 | if weights is not None: 58 | assert len(weights)==len(X) 59 | train_ids = np.where(weights==1)[0] 60 | self.X = X[train_ids,:] 61 | self.y = y[train_ids,:] 62 | else: 63 | self.X = X 64 | self.y = y 65 | assert len(X)==len(y) 66 | 67 | # Parse Model config 68 | self.mc = mc 69 | self.batch_size = mc.batch_size 70 | 71 | self.unique_idxs = np.unique(self.X[:, 0]) 72 | assert len(self.unique_idxs)==len(self.X) 73 | self.n_series = len(self.unique_idxs) 74 | 75 | #assert self.batch_size <= self.n_series 76 | 77 | # Initialize batch iterator 78 | self.b = 0 79 | self.n_batches = int(np.ceil(self.n_series / self.batch_size)) 80 | shuffle = list(range(self.n_series)) 81 | self.sort_key = {'unique_id': [self.unique_idxs[i] for i in shuffle], 82 | 'sort_key': shuffle} 83 | 84 | def update_batch_size(self, new_batch_size): 85 | self.batch_size = new_batch_size 86 | assert self.batch_size <= self.n_series 87 | self.n_batches = int(np.ceil(self.n_series / self.batch_size)) 88 | 89 | def shuffle_dataset(self, random_seed=1): 90 | """Return the examples in the dataset in order, or shuffled.""" 91 | # Random Seed 92 | np.random.seed(random_seed) 93 | self.random_seed = random_seed 94 | shuffle = np.random.choice(self.n_series, self.n_series, replace=False) 95 | self.X = self.X[shuffle] 96 | self.y = self.y[shuffle] 97 | 98 | old_sort_key = self.sort_key['sort_key'] 99 | old_unique_idxs = self.sort_key['unique_id'] 100 | self.sort_key = {'unique_id': [old_unique_idxs[i] for i in shuffle], 101 | 'sort_key': [old_sort_key[i] for i in shuffle]} 102 | 103 | def get_trim_batch(self, unique_id): 104 | if unique_id==None: 105 | # Compute the indexes of the minibatch. 
106 | first = (self.b * self.batch_size) 107 | last = min((first + self.batch_size), self.n_series) 108 | else: 109 | # Obtain unique_id index 110 | assert unique_id in self.sort_key['unique_id'], "unique_id, not fitted" 111 | first = self.sort_key['unique_id'].index(unique_id) 112 | last = first+1 113 | 114 | # Extract values for batch 115 | unique_idxs = self.sort_key['unique_id'][first:last] 116 | batch_idxs = self.sort_key['sort_key'][first:last] 117 | 118 | batch_y = self.y[first:last] 119 | batch_categories = self.X[first:last, 1] 120 | batch_last_ds = self.X[first:last, 2] 121 | 122 | len_series = np.count_nonzero(~np.isnan(batch_y), axis=1) 123 | min_len = min(len_series) 124 | last_numeric = (~np.isnan(batch_y)).cumsum(1).argmax(1)+1 125 | 126 | # Trimming to match min_len 127 | y_b = np.zeros((batch_y.shape[0], min_len)) 128 | for i in range(batch_y.shape[0]): 129 | y_b[i] = batch_y[i,(last_numeric[i]-min_len):last_numeric[i]] 130 | batch_y = y_b 131 | 132 | assert not np.isnan(batch_y).any(), \ 133 | "clean np.nan's from unique_idxs: {}".format(unique_idxs) 134 | assert batch_y.shape[0]==len(batch_idxs)==len(batch_last_ds)==len(batch_categories) 135 | assert batch_y.shape[1]>=1 136 | 137 | # Feed to Batch 138 | batch = Batch(mc=self.mc, y=batch_y, last_ds=batch_last_ds, 139 | categories=batch_categories, idxs=batch_idxs) 140 | self.b = (self.b + 1) % self.n_batches 141 | return batch 142 | 143 | def get_batch(self, unique_id=None): 144 | return self.get_trim_batch(unique_id) 145 | 146 | def __len__(self): 147 | return self.n_batches 148 | 149 | def __iter__(self): 150 | pass 151 | -------------------------------------------------------------------------------- /ESRNN/m4_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import itertools 4 | import ast 5 | import pickle 6 | import time 7 | 8 | import os 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from ESRNN.m4_data import prepare_m4_data 13 | from ESRNN.utils_evaluation import evaluate_prediction_owa 14 | from ESRNN.utils_configs import get_config 15 | 16 | from ESRNN import ESRNN 17 | 18 | import torch 19 | 20 | def main(args): 21 | config = get_config(args.dataset) 22 | if config['data_parameters']['frequency'] == 'Y': 23 | config['data_parameters']['frequency'] = None 24 | 25 | #Setting needed parameters 26 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 27 | 28 | if args.num_obs: 29 | num_obs = args.num_obs 30 | else: 31 | num_obs = 100000 32 | 33 | if args.use_cpu == 1: 34 | config['device'] = 'cpu' 35 | else: 36 | assert torch.cuda.is_available(), 'No cuda devices detected. You can try using CPU instead.' 
37 | 38 | #Reading data 39 | print('Reading data') 40 | X_train_df, y_train_df, X_test_df, y_test_df = prepare_m4_data(dataset_name=args.dataset, 41 | directory=args.results_directory, 42 | num_obs=num_obs) 43 | 44 | # Instantiate model 45 | model = ESRNN(max_epochs=config['train_parameters']['max_epochs'], 46 | batch_size=config['train_parameters']['batch_size'], 47 | freq_of_test=config['train_parameters']['freq_of_test'], 48 | learning_rate=float(config['train_parameters']['learning_rate']), 49 | lr_scheduler_step_size=config['train_parameters']['lr_scheduler_step_size'], 50 | lr_decay=config['train_parameters']['lr_decay'], 51 | per_series_lr_multip=config['train_parameters']['per_series_lr_multip'], 52 | gradient_clipping_threshold=config['train_parameters']['gradient_clipping_threshold'], 53 | rnn_weight_decay=config['train_parameters']['rnn_weight_decay'], 54 | noise_std=config['train_parameters']['noise_std'], 55 | level_variability_penalty=config['train_parameters']['level_variability_penalty'], 56 | testing_percentile=config['train_parameters']['testing_percentile'], 57 | training_percentile=config['train_parameters']['training_percentile'], 58 | ensemble=config['train_parameters']['ensemble'], 59 | max_periods=config['data_parameters']['max_periods'], 60 | seasonality=config['data_parameters']['seasonality'], 61 | input_size=config['data_parameters']['input_size'], 62 | output_size=config['data_parameters']['output_size'], 63 | frequency=config['data_parameters']['frequency'], 64 | cell_type=config['model_parameters']['cell_type'], 65 | state_hsize=config['model_parameters']['state_hsize'], 66 | dilations=config['model_parameters']['dilations'], 67 | add_nl_layer=config['model_parameters']['add_nl_layer'], 68 | random_seed=config['model_parameters']['random_seed'], 69 | device=config['device']) 70 | 71 | if args.test == 1: 72 | model = ESRNN(max_epochs=1, 73 | batch_size=20, 74 | seasonality=config['data_parameters']['seasonality'], 75 | input_size=config['data_parameters']['input_size'], 76 | output_size=config['data_parameters']['output_size'], 77 | frequency=config['data_parameters']['frequency'], 78 | device=config['device']) 79 | 80 | # Fit model 81 | # If y_test_df is provided the model will evaluate predictions on this set every freq_test epochs 82 | model.fit(X_train_df, y_train_df, X_test_df, y_test_df) 83 | 84 | # Predict on test set 85 | print('\nForecasting') 86 | y_hat_df = model.predict(X_test_df) 87 | 88 | # Evaluate predictions 89 | print(15*'=', ' Final evaluation ', 14*'=') 90 | seasonality = config['data_parameters']['seasonality'] 91 | if not seasonality: 92 | seasonality = 1 93 | else: 94 | seasonality = seasonality[0] 95 | 96 | final_owa, final_mase, final_smape = evaluate_prediction_owa(y_hat_df, y_train_df, 97 | X_test_df, y_test_df, 98 | naive2_seasonality=seasonality) 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser(description='Replicate M4 results for the ESRNN model') 102 | parser.add_argument("--dataset", required=True, type=str, 103 | choices=['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Hourly', 'Daily'], 104 | help="set of M4 time series to be tested") 105 | parser.add_argument("--results_directory", required=True, type=str, 106 | help="directory where M4 data will be downloaded") 107 | parser.add_argument("--gpu_id", required=False, type=int, 108 | help="an integer that specifies which GPU will be used") 109 | parser.add_argument("--use_cpu", required=False, type=int, 110 | help="1 to use CPU instead of GPU (uses GPU by 
default)") 111 | parser.add_argument("--num_obs", required=False, type=int, 112 | help="number of M4 time series to be tested (uses all data by default)") 113 | parser.add_argument("--test", required=False, type=int, 114 | help="run fast for tests (no test by default)") 115 | args = parser.parse_args() 116 | 117 | main(args) 118 | -------------------------------------------------------------------------------- /ESRNN/utils_configs.py: -------------------------------------------------------------------------------- 1 | def get_config(dataset_name): 2 | """ 3 | Returns dict config 4 | 5 | Parameters 6 | ---------- 7 | dataset_name: str 8 | """ 9 | allowed_dataset_names = ('Yearly', 'Monthly', 'Weekly', 'Hourly', 'Quarterly', 'Daily') 10 | if dataset_name not in allowed_dataset_names: 11 | raise ValueError(f'kind must be one of {allowed_kinds}') 12 | 13 | if dataset_name == 'Yearly': 14 | return YEARLY 15 | elif dataset_name == 'Monthly': 16 | return MONTHLY 17 | elif dataset_name == 'Weekly': 18 | return WEEKLY 19 | elif dataset_name == 'Hourly': 20 | return HOURLY 21 | elif dataset_name == 'Quarterly': 22 | return QUARTERLY 23 | elif dataset_name == 'Daily': 24 | return DAILY 25 | 26 | YEARLY = { 27 | 'device': 'cuda', 28 | 'train_parameters': { 29 | 'max_epochs': 25, 30 | 'batch_size': 4, 31 | 'freq_of_test': 5, 32 | 'learning_rate': '1e-4', 33 | 'lr_scheduler_step_size': 10, 34 | 'lr_decay': 0.1, 35 | 'per_series_lr_multip': 0.8, 36 | 'gradient_clipping_threshold': 50, 37 | 'rnn_weight_decay': 0, 38 | 'noise_std': 0.001, 39 | 'level_variability_penalty': 100, 40 | 'testing_percentile': 50, 41 | 'training_percentile': 50, 42 | 'ensemble': False 43 | }, 44 | 'data_parameters': { 45 | 'max_periods': 25, 46 | 'seasonality': [], 47 | 'input_size': 4, 48 | 'output_size': 6, 49 | 'frequency': 'Y' 50 | }, 51 | 'model_parameters': { 52 | 'cell_type': 'LSTM', 53 | 'state_hsize': 40, 54 | 'dilations': [[1], [6]], 55 | 'add_nl_layer': False, 56 | 'random_seed': 117982 57 | } 58 | } 59 | 60 | MONTHLY = { 61 | 'device': 'cuda', 62 | 'train_parameters': { 63 | 'max_epochs': 15, 64 | 'batch_size': 64, 65 | 'freq_of_test': 4, 66 | 'learning_rate': '7e-4', 67 | 'lr_scheduler_step_size': 12, 68 | 'lr_decay': 0.2, 69 | 'per_series_lr_multip': 0.5, 70 | 'gradient_clipping_threshold': 20, 71 | 'rnn_weight_decay': 0, 72 | 'noise_std': 0.001, 73 | 'level_variability_penalty': 50, 74 | 'testing_percentile': 50, 75 | 'training_percentile': 45, 76 | 'ensemble': False 77 | }, 78 | 'data_parameters': { 79 | 'max_periods': 36, 80 | 'seasonality': [12], 81 | 'input_size': 12, 82 | 'output_size': 18, 83 | 'frequency': 'M' 84 | }, 85 | 'model_parameters': { 86 | 'cell_type': 'LSTM', 87 | 'state_hsize': 50, 88 | 'dilations': [[1, 3, 6, 12]], 89 | 'add_nl_layer': False, 90 | 'random_seed': 1 91 | } 92 | } 93 | 94 | 95 | WEEKLY = { 96 | 'device': 'cuda', 97 | 'train_parameters': { 98 | 'max_epochs': 50, 99 | 'batch_size': 32, 100 | 'freq_of_test': 10, 101 | 'learning_rate': '1e-2', 102 | 'lr_scheduler_step_size': 10, 103 | 'lr_decay': 0.5, 104 | 'per_series_lr_multip': 1.0, 105 | 'gradient_clipping_threshold': 20, 106 | 'rnn_weight_decay': 0, 107 | 'noise_std': 0.001, 108 | 'level_variability_penalty': 100, 109 | 'testing_percentile': 50, 110 | 'training_percentile': 50, 111 | 'ensemble': True 112 | }, 113 | 'data_parameters': { 114 | 'max_periods': 31, 115 | 'seasonality': [], 116 | 'input_size': 10, 117 | 'output_size': 13, 118 | 'frequency': 'W' 119 | }, 120 | 'model_parameters': { 121 | 'cell_type': 'ResLSTM', 122 | 
'state_hsize': 40, 123 | 'dilations': [[1, 52]], 124 | 'add_nl_layer': False, 125 | 'random_seed': 2 126 | } 127 | } 128 | 129 | HOURLY = { 130 | 'device': 'cuda', 131 | 'train_parameters': { 132 | 'max_epochs': 20, 133 | 'batch_size': 32, 134 | 'freq_of_test': 5, 135 | 'learning_rate': '1e-2', 136 | 'lr_scheduler_step_size': 7, 137 | 'lr_decay': 0.5, 138 | 'per_series_lr_multip': 1.0, 139 | 'gradient_clipping_threshold': 50, 140 | 'rnn_weight_decay': 0, 141 | 'noise_std': 0.001, 142 | 'level_variability_penalty': 30, 143 | 'testing_percentile': 50, 144 | 'training_percentile': 50, 145 | 'ensemble': True 146 | }, 147 | 'data_parameters': { 148 | 'max_periods': 371, 149 | 'seasonality': [24, 168], 150 | 'input_size': 24, 151 | 'output_size': 48, 152 | 'frequency': 'H' 153 | }, 154 | 'model_parameters': { 155 | 'cell_type': 'LSTM', 156 | 'state_hsize': 40, 157 | 'dilations': [[1, 4, 24, 168]], 158 | 'add_nl_layer': False, 159 | 'random_seed': 1 160 | } 161 | } 162 | 163 | QUARTERLY = { 164 | 'device': 'cuda', 165 | 'train_parameters': { 166 | 'max_epochs': 30, 167 | 'batch_size': 16, 168 | 'freq_of_test': 5, 169 | 'learning_rate': '5e-4', 170 | 'lr_scheduler_step_size': 10, 171 | 'lr_decay': 0.5, 172 | 'per_series_lr_multip': 1.0, 173 | 'gradient_clipping_threshold': 20, 174 | 'rnn_weight_decay': 0, 175 | 'noise_std': 0.001, 176 | 'level_variability_penalty': 100, 177 | 'testing_percentile': 50, 178 | 'training_percentile': 50, 179 | 'ensemble': False 180 | }, 181 | 'data_parameters': { 182 | 'max_periods': 20, 183 | 'seasonality': [4], 184 | 'input_size': 4, 185 | 'output_size': 8, 186 | 'frequency': 'Q' 187 | }, 188 | 'model_parameters': { 189 | 'cell_type': 'LSTM', 190 | 'state_hsize': 40, 191 | 'dilations': [[1, 2, 4, 8]], 192 | 'add_nl_layer': False, 193 | 'random_seed': 3 194 | } 195 | } 196 | 197 | DAILY = { 198 | 'device': 'cuda', 199 | 'train_parameters': { 200 | 'max_epochs': 20, 201 | 'batch_size': 64, 202 | 'freq_of_test': 2, 203 | 'learning_rate': '1e-2', 204 | 'lr_scheduler_step_size': 4, 205 | 'lr_decay': 0.3333, 206 | 'per_series_lr_multip': 0.5, 207 | 'gradient_clipping_threshold': 50, 208 | 'rnn_weight_decay': 0, 209 | 'noise_std': 0.0001, 210 | 'level_variability_penalty': 100, 211 | 'testing_percentile': 50, 212 | 'training_percentile': 65, 213 | 'ensemble': False 214 | }, 215 | 'data_parameters': { 216 | 'max_periods': 15, 217 | 'seasonality': [7], 218 | 'input_size': 7, 219 | 'output_size': 14, 220 | 'frequency': 'D' 221 | }, 222 | 'model_parameters': { 223 | 'n_models': 5, 224 | 'n_top': 4, 225 | 'cell_type': 226 | 'LSTM', 227 | 'state_hsize': 40, 228 | 'dilations': [[1, 7, 28]], 229 | 'add_nl_layer': True, 230 | 'random_seed': 1 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build](https://github.com/kdgutier/esrnn_torch/workflows/Python%20package/badge.svg?branch=master)](https://github.com/kdgutier/esrnn_torch/tree/master) 2 | [![PyPI version fury.io](https://badge.fury.io/py/ESRNN.svg)](https://pypi.python.org/pypi/ESRNN/) 3 | [![Downloads](https://pepy.tech/badge/esrnn)](https://pepy.tech/project/esrnn) 4 | [![Python 3.6+](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360+/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/kdgutier/esrnn_torch/blob/master/LICENSE) 6 | 7 | 8 | # Pytorch Implementation of the ES-RNN 
9 | Pytorch implementation of the ES-RNN algorithm proposed by Smyl, the winning submission of the M4 Forecasting Competition. The class exposes `fit` and `predict` methods to facilitate interaction with machine learning pipelines, along with evaluation and data wrangling utilities. Developed by [Autonlab](https://www.autonlab.org/)'s members at Carnegie Mellon University. 10 | 11 | ## Installation Prerequisites 12 | * numpy>=1.16.1 13 | * pandas>=0.25.2 14 | * pytorch>=1.3.1 15 | 16 | ## Installation 17 | 18 | This code is a work in progress; any contributions or issues are welcome on 19 | GitHub at: https://github.com/kdgutier/esrnn_torch 20 | 21 | You can install the *released version* of `ESRNN` from the [Python package index](https://pypi.org) with: 22 | 23 | ```console 24 | pip install ESRNN 25 | ``` 26 | 27 | ## Usage 28 | 29 | ### Input data 30 | 31 | The fit method receives `X_df` and `y_df`, training pandas dataframes in long format. Optionally, `X_test_df` and `y_test_df` can be provided to compute out-of-sample performance. 32 | - `X_df` must contain the columns `['unique_id', 'ds', 'x']` 33 | - `y_df` must contain the columns `['unique_id', 'ds', 'y']` 34 | - `X_test_df` must contain the columns `['unique_id', 'ds', 'x']` 35 | - `y_test_df` must contain the columns `['unique_id', 'ds', 'y']` and a benchmark model to compare against (default `'y_hat_naive2'`). 36 | 37 | For all the above: 38 | - The column `'unique_id'` is a time series identifier and the column `'ds'` stands for the datetime. 39 | - Column `'x'` is an exogenous categorical feature. 40 | - Column `'y'` is the target variable. 41 | - Column `'y'` **does not allow negative values** and the first entry for all series must be **greater than 0**. 42 | 43 | The `X` and `y` dataframes must contain the same values in the `'unique_id'` and `'ds'` columns and be **balanced**, i.e. no *gaps* between dates for the given frequency. A minimal construction sketch is given after the example table below. 44 | 45 | 46 | 
47 | 48 | |`X_df`|`y_df` |`X_test_df`| `y_test_df`| 49 | |:-----------:|:-----------:|:-----------:|:-----------:| 50 | |![X_df example](https://raw.githubusercontent.com/kdgutier/esrnn_torch/master/.github/images/x_train.png)|![y_df example](https://raw.githubusercontent.com/kdgutier/esrnn_torch/master/.github/images/y_train.png)|![X_test_df example](https://raw.githubusercontent.com/kdgutier/esrnn_torch/master/.github/images/x_test.png)|![y_test_df example](https://raw.githubusercontent.com/kdgutier/esrnn_torch/master/.github/images/y_test.png)| 51 | 52 | 
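For concreteness, here is a minimal sketch of how such input dataframes can be built by hand. The series identifiers, category labels and values below are made-up illustrations; the only requirements are the column layout and balance described above:

```python
import pandas as pd

# Two balanced monthly series identified by 'unique_id'.
# 'ds' holds the datestamps, 'x' an exogenous categorical
# feature, and 'y' the strictly positive target.
dates = pd.date_range('2015-01-31', periods=4, freq='M')
X_df = pd.DataFrame({'unique_id': ['M1'] * 4 + ['M2'] * 4,
                     'ds': list(dates) * 2,
                     'x': ['Macro'] * 4 + ['Micro'] * 4})
y_df = pd.DataFrame({'unique_id': ['M1'] * 4 + ['M2'] * 4,
                     'ds': list(dates) * 2,
                     'y': [10.0, 12.0, 13.0, 15.0,
                           20.0, 21.0, 19.0, 24.0]})
```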
53 | 54 | 55 | ### M4 example 56 | 57 | 58 | ```python 59 | from ESRNN.m4_data import prepare_m4_data 60 | from ESRNN.utils_evaluation import evaluate_prediction_owa 61 | 62 | from ESRNN import ESRNN 63 | 64 | X_train_df, y_train_df, X_test_df, y_test_df = prepare_m4_data(dataset_name='Yearly', 65 | directory = './data', 66 | num_obs=1000) 67 | 68 | # Instantiate model 69 | model = ESRNN(max_epochs=25, freq_of_test=5, batch_size=4, learning_rate=1e-4, 70 | per_series_lr_multip=0.8, lr_scheduler_step_size=10, 71 | lr_decay=0.1, gradient_clipping_threshold=50, 72 | rnn_weight_decay=0.0, level_variability_penalty=100, 73 | testing_percentile=50, training_percentile=50, 74 | ensemble=False, max_periods=25, seasonality=[], 75 | input_size=4, output_size=6, 76 | cell_type='LSTM', state_hsize=40, 77 | dilations=[[1], [6]], add_nl_layer=False, 78 | random_seed=1, device='cpu') 79 | 80 | # Fit model 81 | # If y_test_df is provided the model 82 | # will evaluate predictions on 83 | # this set every freq_test epochs 84 | model.fit(X_train_df, y_train_df, X_test_df, y_test_df) 85 | 86 | # Predict on test set 87 | y_hat_df = model.predict(X_test_df) 88 | 89 | # Evaluate predictions 90 | final_owa, final_mase, final_smape = evaluate_prediction_owa(y_hat_df, y_train_df, 91 | X_test_df, y_test_df, 92 | naive2_seasonality=1) 93 | ``` 94 | ## Overall Weighted Average 95 | 96 | A useful metric for quantifying the aggregate error of a model across many time series is the Overall Weighted Average (OWA) proposed for the M4 competition. It is obtained by computing the symmetric mean absolute percentage error (sMAPE) and the mean absolute scaled error (MASE) over all the model's time series, dividing each by the corresponding value for the Naive2 benchmark predictions, and averaging the two resulting ratios. Both sMAPE and MASE are scale independent. These measurements are calculated as follows (a short computational sketch is also included just before the results table below): 97 | 98 | ![OWA](https://raw.githubusercontent.com/kdgutier/esrnn_torch/master/.github/images/metrics.png) 99 | 100 | 101 | 102 | ## Current Results 103 | Here we used the model directly, to compare against the original implementation. It is worth noting that these results do not include the ensemble methods mentioned in the [ESRNN paper](https://www.sciencedirect.com/science/article/pii/S0169207019301153).
104 | [Results of the M4 competition](https://www.researchgate.net/publication/325901666_The_M4_Competition_Results_findings_conclusion_and_way_forward). 105 |
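As a reference for the numbers below, here is a minimal sketch of the three metrics for a single series. The helper names are ours, and the per-series averaging and Naive2 fitting are omitted; the package's `ESRNN.utils_evaluation.evaluate_prediction_owa` computes all of this end to end:

```python
import numpy as np

def smape(y, y_hat):
    # Symmetric MAPE in percent, as defined for the M4 competition.
    return 200 * np.mean(np.abs(y - y_hat) / (np.abs(y) + np.abs(y_hat)))

def mase(y, y_hat, y_insample, seasonality):
    # MAE scaled by the in-sample seasonal-naive MAE.
    scale = np.mean(np.abs(y_insample[seasonality:] - y_insample[:-seasonality]))
    return np.mean(np.abs(y - y_hat)) / scale

def owa(smape_model, mase_model, smape_naive2, mase_naive2):
    # Average of the sMAPE and MASE ratios relative to Naive2.
    return (smape_model / smape_naive2 + mase_model / mase_naive2) / 2
```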
106 | 107 | | DATASET | OUR OWA | M4 OWA (Smyl) | 108 | |-----------|:---------:|:--------:| 109 | | Yearly | 0.785 | 0.778 | 110 | | Quarterly | 0.879 | 0.847 | 111 | | Monthly | 0.872 | 0.836 | 112 | | Hourly | 0.615 | 0.920 | 113 | | Weekly | 0.952 | 0.920 | 114 | | Daily | 0.968 | 0.920 | 115 | 116 | 117 | ## Replicating M4 results 118 | 119 | 120 | Replicating the M4 results is as easy as running the following line of code (for each frequency) after installing the package via pip: 121 | 122 | ```console 123 | python -m ESRNN.m4_run --dataset 'Yearly' --results_directory '/some/path' \ 124 | --gpu_id 0 --use_cpu 0 125 | ``` 126 | 127 | Use `--help` to get the description of each argument: 128 | 129 | ```console 130 | python -m ESRNN.m4_run --help 131 | ``` 132 | 133 | ## Authors 134 | This repository was developed with joint efforts from AutonLab researchers at Carnegie Mellon University and Orax data scientists. 135 | * **Kin Gutierrez** - [kdgutier](https://github.com/kdgutier) 136 | * **Cristian Challu** - [cristianchallu](https://github.com/cristianchallu) 137 | * **Federico Garza** - [FedericoGarza](https://github.com/FedericoGarza) - [mail](mailto:fede.garza.ramirez@gmail.com) 138 | * **Max Mergenthaler** - [mergenthaler](https://github.com/mergenthaler) 139 | 140 | ## License 141 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/kdgutier/esrnn_torch/blob/master/LICENSE) file for details. 142 | 143 | 144 | ## REFERENCES 145 | 1. [A hybrid method of exponential smoothing and recurrent neural networks for time series forecasting](https://www.sciencedirect.com/science/article/pii/S0169207019301153) 146 | 2. [The M4 Competition: Results, findings, conclusion and way forward](https://www.researchgate.net/publication/325901666_The_M4_Competition_Results_findings_conclusion_and_way_forward) 147 | 3. [M4 Competition Data](https://github.com/M4Competition/M4-methods/tree/master/Dataset) 148 | 4. [Dilated Recurrent Neural Networks](https://papers.nips.cc/paper/6613-dilated-recurrent-neural-networks.pdf) 149 | 5. [Residual LSTM: Design of a Deep Recurrent Architecture for Distant Speech Recognition](https://arxiv.org/abs/1701.03360) 150 | 6. [A Dual-Stage Attention-Based recurrent neural network for time series prediction](https://arxiv.org/abs/1704.02971) 151 | -------------------------------------------------------------------------------- /ESRNN/m4_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from six.moves import urllib 3 | import subprocess 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from ESRNN.utils_evaluation import Naive2 9 | 10 | 11 | seas_dict = {'Hourly': {'seasonality': 24, 'input_size': 24, 12 | 'output_size': 48, 'freq': 'H'}, 13 | 'Daily': {'seasonality': 7, 'input_size': 7, 14 | 'output_size': 14, 'freq': 'D'}, 15 | 'Weekly': {'seasonality': 52, 'input_size': 52, 16 | 'output_size': 13, 'freq': 'W'}, 17 | 'Monthly': {'seasonality': 12, 'input_size': 12, 18 | 'output_size':18, 'freq': 'M'}, 19 | 'Quarterly': {'seasonality': 4, 'input_size': 4, 20 | 'output_size': 8, 'freq': 'Q'}, 21 | 'Yearly': {'seasonality': 1, 'input_size': 4, 22 | 'output_size': 6, 'freq': 'D'}} 23 | 24 | SOURCE_URL = 'https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/' 25 | 26 | 27 | def maybe_download(filename, directory): 28 | """ 29 | Download the data from M4's website, unless it's already here. 
30 | 31 | Parameters 32 | ---------- 33 | filename: str 34 | Filename of M4 data with format /Type/Frequency.csv. Example: /Test/Daily-train.csv 35 | directory: str 36 | Custom directory where data will be downloaded. 37 | """ 38 | data_directory = directory + "/m4" 39 | train_directory = data_directory + "/Train/" 40 | test_directory = data_directory + "/Test/" 41 | 42 | if not os.path.exists(data_directory): 43 | os.mkdir(data_directory) 44 | if not os.path.exists(train_directory): 45 | os.mkdir(train_directory) 46 | if not os.path.exists(test_directory): 47 | os.mkdir(test_directory) 48 | 49 | filepath = os.path.join(data_directory, filename) 50 | if not os.path.exists(filepath): 51 | filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 52 | size = os.path.getsize(filepath) 53 | print('Successfully downloaded', filename, size, 'bytes.') 54 | return filepath 55 | 56 | def m4_parser(dataset_name, directory, num_obs=1000000): 57 | """ 58 | Transform M4 data into a panel. 59 | 60 | Parameters 61 | ---------- 62 | dataset_name: str 63 | Frequency of the data. Example: 'Yearly'. 64 | directory: str 65 | Custom directory where data will be saved. 66 | num_obs: int 67 | Number of time series to return. 68 | """ 69 | data_directory = directory + "/m4" 70 | train_directory = data_directory + "/Train/" 71 | test_directory = data_directory + "/Test/" 72 | freq = seas_dict[dataset_name]['freq'] 73 | 74 | m4_info = pd.read_csv(data_directory+'/M4-info.csv', usecols=['M4id','category']) 75 | m4_info = m4_info[m4_info['M4id'].str.startswith(dataset_name[0])].reset_index(drop=True) 76 | 77 | # Train data 78 | train_path='{}{}-train.csv'.format(train_directory, dataset_name) 79 | 80 | train_df = pd.read_csv(train_path, nrows=num_obs) 81 | train_df = train_df.rename(columns={'V1':'unique_id'}) 82 | 83 | train_df = pd.wide_to_long(train_df, stubnames=["V"], i="unique_id", j="ds").reset_index() 84 | train_df = train_df.rename(columns={'V':'y'}) 85 | train_df = train_df.dropna() 86 | train_df['split'] = 'train' 87 | train_df['ds'] = train_df['ds']-1 88 | # Get len of series per unique_id 89 | len_series = train_df.groupby('unique_id').agg({'ds': 'max'}).reset_index() 90 | len_series.columns = ['unique_id', 'len_serie'] 91 | 92 | # Test data 93 | test_path='{}{}-test.csv'.format(test_directory, dataset_name) 94 | 95 | test_df = pd.read_csv(test_path, nrows=num_obs) 96 | test_df = test_df.rename(columns={'V1':'unique_id'}) 97 | 98 | test_df = pd.wide_to_long(test_df, stubnames=["V"], i="unique_id", j="ds").reset_index() 99 | test_df = test_df.rename(columns={'V':'y'}) 100 | test_df = test_df.dropna() 101 | test_df['split'] = 'test' 102 | test_df = test_df.merge(len_series, on='unique_id') 103 | test_df['ds'] = test_df['ds'] + test_df['len_serie'] - 1 104 | test_df = test_df[['unique_id','ds','y','split']] 105 | 106 | df = pd.concat((train_df,test_df)) 107 | df = df.sort_values(by=['unique_id', 'ds']).reset_index(drop=True) 108 | 109 | # Create column with dates with freq of dataset 110 | len_series = df.groupby('unique_id').agg({'ds': 'max'}).reset_index() 111 | dates = [] 112 | for i in range(len(len_series)): 113 | len_serie = len_series.iloc[i,1] 114 | ranges = pd.date_range(start='1970/01/01', periods=len_serie, freq=freq) 115 | dates += list(ranges) 116 | df.loc[:,'ds'] = dates 117 | 118 | df = df.merge(m4_info, left_on=['unique_id'], right_on=['M4id']) 119 | df.drop(columns=['M4id'], inplace=True) 120 | df = df.rename(columns={'category': 'x'}) 121 | 122 | X_train_df = 
df[df['split']=='train'].filter(items=['unique_id', 'ds', 'x']) 123 | y_train_df = df[df['split']=='train'].filter(items=['unique_id', 'ds', 'y']) 124 | X_test_df = df[df['split']=='test'].filter(items=['unique_id', 'ds', 'x']) 125 | y_test_df = df[df['split']=='test'].filter(items=['unique_id', 'ds', 'y']) 126 | 127 | X_train_df = X_train_df.reset_index(drop=True) 128 | y_train_df = y_train_df.reset_index(drop=True) 129 | X_test_df = X_test_df.reset_index(drop=True) 130 | y_test_df = y_test_df.reset_index(drop=True) 131 | 132 | return X_train_df, y_train_df, X_test_df, y_test_df 133 | 134 | def naive2_predictions(dataset_name, directory, num_obs, y_train_df = None, y_test_df = None): 135 | """ 136 | Computes Naive2 predictions. 137 | 138 | Parameters 139 | ---------- 140 | dataset_name: str 141 | Frequency of the data. Example: 'Yearly'. 142 | directory: str 143 | Custom directory where data will be saved. 144 | num_obs: int 145 | Number of time series to return. 146 | y_train_df: DataFrame 147 | Y train set returned by m4_parser 148 | y_test_df: DataFrame 149 | Y test set returned by m4_parser 150 | """ 151 | # Read train and test data 152 | if (y_train_df is None) or (y_test_df is None): 153 | _, y_train_df, _, y_test_df = m4_parser(dataset_name, directory, num_obs) 154 | 155 | seasonality = seas_dict[dataset_name]['seasonality'] 156 | input_size = seas_dict[dataset_name]['input_size'] 157 | output_size = seas_dict[dataset_name]['output_size'] 158 | freq = seas_dict[dataset_name]['freq'] 159 | 160 | print('Preparing {} dataset'.format(dataset_name)) 161 | print('Preparing Naive2 {} dataset predictions'.format(dataset_name)) 162 | 163 | # Naive2 164 | y_naive2_df = pd.DataFrame(columns=['unique_id', 'ds', 'y_hat']) 165 | 166 | # Sort X by unique_id for faster loop 167 | y_train_df = y_train_df.sort_values(by=['unique_id', 'ds']) 168 | # List of unique ids 169 | unique_ids = y_train_df['unique_id'].unique() 170 | # Panel of fitted models 171 | for unique_id in unique_ids: 172 | # Fast filter X and y by id. 173 | top_row = np.asscalar(y_train_df['unique_id'].searchsorted(unique_id, 'left')) 174 | bottom_row = np.asscalar(y_train_df['unique_id'].searchsorted(unique_id, 'right')) 175 | y_id = y_train_df[top_row:bottom_row] 176 | 177 | y_naive2 = pd.DataFrame(columns=['unique_id', 'ds', 'y_hat']) 178 | y_naive2['ds'] = pd.date_range(start=y_id.ds.max(), 179 | periods=output_size+1, freq=freq)[1:] 180 | y_naive2['unique_id'] = unique_id 181 | y_naive2['y_hat'] = Naive2(seasonality).fit(y_id.y.to_numpy()).predict(output_size) 182 | y_naive2_df = y_naive2_df.append(y_naive2) 183 | 184 | y_naive2_df = y_test_df.merge(y_naive2_df, on=['unique_id', 'ds'], how='left') 185 | y_naive2_df.rename(columns={'y_hat': 'y_hat_naive2'}, inplace=True) 186 | 187 | results_dir = directory + '/results' 188 | naive2_file = results_dir + '/{}-naive2predictions_{}.csv'.format(dataset_name, num_obs) 189 | y_naive2_df.to_csv(naive2_file, encoding='utf-8', index=None) 190 | 191 | return y_naive2_df 192 | 193 | def prepare_m4_data(dataset_name, directory, num_obs): 194 | """ 195 | Pipeline that obtains M4 time series, transforms them and gets Naive2 predictions. 196 | 197 | Parameters 198 | ---------- 199 | dataset_name: str 200 | Frequency of the data. Example: 'Yearly'. 201 | directory: str 202 | Custom directory where data will be saved. 203 | num_obs: int 204 | Number of time series to return. 
205 | py_predictions: bool 206 | whether to use Python or R predictions 207 | """ 208 | m4info_filename = maybe_download('M4-info.csv', directory) 209 | 210 | dailytrain_filename = maybe_download('Train/Daily-train.csv', directory) 211 | hourlytrain_filename = maybe_download('Train/Hourly-train.csv', directory) 212 | monthlytrain_filename = maybe_download('Train/Monthly-train.csv', directory) 213 | quarterlytrain_filename = maybe_download('Train/Quarterly-train.csv', directory) 214 | weeklytrain_filename = maybe_download('Train/Weekly-train.csv', directory) 215 | yearlytrain_filename = maybe_download('Train/Yearly-train.csv', directory) 216 | 217 | dailytest_filename = maybe_download('Test/Daily-test.csv', directory) 218 | hourlytest_filename = maybe_download('Test/Hourly-test.csv', directory) 219 | monthlytest_filename = maybe_download('Test/Monthly-test.csv', directory) 220 | quarterlytest_filename = maybe_download('Test/Quarterly-test.csv', directory) 221 | weeklytest_filename = maybe_download('Test/Weekly-test.csv', directory) 222 | yearlytest_filename = maybe_download('Test/Yearly-test.csv', directory) 223 | print('\n') 224 | 225 | X_train_df, y_train_df, X_test_df, y_test_df = m4_parser(dataset_name, directory, num_obs) 226 | 227 | results_dir = directory + '/results' 228 | if not os.path.exists(results_dir): 229 | os.mkdir(results_dir) 230 | 231 | naive2_file = results_dir + '/{}-naive2predictions_{}.csv' 232 | naive2_file = naive2_file.format(dataset_name, num_obs) 233 | 234 | if not os.path.exists(naive2_file): 235 | y_naive2_df = naive2_predictions(dataset_name, directory, num_obs, y_train_df, y_test_df) 236 | else: 237 | y_naive2_df = pd.read_csv(naive2_file) 238 | y_naive2_df['ds'] = pd.to_datetime(y_naive2_df['ds']) 239 | 240 | return X_train_df, y_train_df, X_test_df, y_naive2_df 241 | -------------------------------------------------------------------------------- /ESRNN/utils/ESRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ESRNN.utils.DRNN import DRNN 4 | import numpy as np 5 | 6 | #import torch.jit as jit 7 | 8 | 9 | class _ES(nn.Module): 10 | def __init__(self, mc): 11 | super(_ES, self).__init__() 12 | self.mc = mc 13 | self.n_series = self.mc.n_series 14 | self.output_size = self.mc.output_size 15 | assert len(self.mc.seasonality) in [0, 1, 2] 16 | 17 | def gaussian_noise(self, input_data, std=0.2): 18 | size = input_data.size() 19 | noise = torch.autograd.Variable(input_data.data.new(size).normal_(0, std)) 20 | return input_data + noise 21 | 22 | #@jit.script_method 23 | def compute_levels_seasons(self, y, idxs): 24 | pass 25 | 26 | def normalize(self, y, level, seasonalities): 27 | pass 28 | 29 | def predict(self, trend, levels, seasonalities): 30 | pass 31 | 32 | def forward(self, ts_object): 33 | # parse mc 34 | input_size = self.mc.input_size 35 | output_size = self.mc.output_size 36 | exogenous_size = self.mc.exogenous_size 37 | noise_std = self.mc.noise_std 38 | seasonality = self.mc.seasonality 39 | batch_size = len(ts_object.idxs) 40 | 41 | # Parse ts_object 42 | y = ts_object.y 43 | idxs = ts_object.idxs 44 | n_series, n_time = y.shape 45 | if self.training: 46 | windows_end = n_time-input_size-output_size+1 47 | windows_range = range(windows_end) 48 | else: 49 | windows_start = n_time-input_size-output_size+1 50 | windows_end = n_time-input_size+1 51 | 52 | windows_range = range(windows_start, windows_end) 53 | n_windows = len(windows_range) 54 | assert n_windows>0 55 
| 56 | # Initialize windows, levels and seasonalities 57 | levels, seasonalities = self.compute_levels_seasons(y, idxs) 58 | windows_y_hat = torch.zeros((n_windows, batch_size, input_size+exogenous_size), 59 | device=self.mc.device) 60 | windows_y = torch.zeros((n_windows, batch_size, output_size), 61 | device=self.mc.device) 62 | 63 | for i, window in enumerate(windows_range): 64 | # Windows yhat 65 | y_hat_start = window 66 | y_hat_end = input_size + window 67 | 68 | # Y_hat deseasonalization and normalization 69 | window_y_hat = self.normalize(y=y[:, y_hat_start:y_hat_end], 70 | level=levels[:, [y_hat_end-1]], 71 | seasonalities=seasonalities, 72 | start=y_hat_start, end=y_hat_end) 73 | 74 | if self.training: 75 | window_y_hat = self.gaussian_noise(window_y_hat, std=noise_std) 76 | 77 | # Concatenate categories 78 | if exogenous_size>0: 79 | window_y_hat = torch.cat((window_y_hat, ts_object.categories), 1) 80 | 81 | windows_y_hat[i, :, :] += window_y_hat 82 | 83 | # Windows y (for loss during train) 84 | if self.training: 85 | y_start = y_hat_end 86 | y_end = y_start+output_size 87 | # Y deseasonalization and normalization 88 | window_y = self.normalize(y=y[:, y_start:y_end], 89 | level=levels[:, [y_start]], 90 | seasonalities=seasonalities, 91 | start=y_start, end=y_end) 92 | windows_y[i, :, :] += window_y 93 | 94 | return windows_y_hat, windows_y, levels, seasonalities 95 | 96 | class _ESM(_ES): 97 | def __init__(self, mc): 98 | super(_ESM, self).__init__(mc) 99 | # Level and Seasonality Smoothing parameters 100 | # 1 level, S seasonalities, S init_seas 101 | embeds_size = 1 + len(self.mc.seasonality) + sum(self.mc.seasonality) 102 | init_embeds = torch.ones((self.n_series, embeds_size)) * 0.5 103 | self.embeds = nn.Embedding(self.n_series, embeds_size) 104 | self.embeds.weight.data.copy_(init_embeds) 105 | self.register_buffer('seasonality', torch.LongTensor(self.mc.seasonality)) 106 | 107 | #@jit.script_method 108 | def compute_levels_seasons(self, y, idxs): 109 | """ 110 | Computes levels and seasons 111 | """ 112 | # Lookup parameters per serie 113 | #seasonality = self.seasonality 114 | embeds = self.embeds(idxs) 115 | lev_sms = torch.sigmoid(embeds[:, 0]) 116 | 117 | # Initialize seasonalities 118 | seas_prod = torch.ones(len(y[:,0])).to(y.device) 119 | #seasonalities1 = torch.jit.annotate(List[Tensor], []) 120 | #seasonalities2 = torch.jit.annotate(List[Tensor], []) 121 | seasonalities1 = [] 122 | seasonalities2 = [] 123 | seas_sms1 = torch.ones(1).to(y.device) 124 | seas_sms2 = torch.ones(1).to(y.device) 125 | 126 | if len(self.seasonality)>0: 127 | seas_sms1 = torch.sigmoid(embeds[:, 1]) 128 | init_seas1 = torch.exp(embeds[:, 2:(2+self.seasonality[0])]).unbind(1) 129 | assert len(init_seas1) == self.seasonality[0] 130 | 131 | for i in range(len(init_seas1)): 132 | seasonalities1 += [init_seas1[i]] 133 | seasonalities1 += [init_seas1[0]] 134 | seas_prod = seas_prod * init_seas1[0] 135 | 136 | if len(self.seasonality)==2: 137 | seas_sms2 = torch.sigmoid(embeds[:, 2+self.seasonality[0]]) 138 | init_seas2 = torch.exp(embeds[:, 3+self.seasonality[0]:]).unbind(1) 139 | assert len(init_seas2) == self.seasonality[1] 140 | 141 | for i in range(len(init_seas2)): 142 | seasonalities2 += [init_seas2[i]] 143 | seasonalities2 += [init_seas2[0]] 144 | seas_prod = seas_prod * init_seas2[0] 145 | 146 | # Initialize levels 147 | #levels = torch.jit.annotate(List[Tensor], []) 148 | levels = [] 149 | levels += [y[:,0]/seas_prod] 150 | 151 | # Recursive seasonalities and levels 152 | ys = 
y.unbind(1) 153 | n_time = len(ys) 154 | for t in range(1, n_time): 155 | 156 | seas_prod_t = torch.ones(len(y[:,t])).to(y.device) 157 | if len(self.seasonality)>0: 158 | seas_prod_t = seas_prod_t * seasonalities1[t] 159 | if len(self.seasonality)==2: 160 | seas_prod_t = seas_prod_t * seasonalities2[t] 161 | 162 | newlev = lev_sms * (ys[t] / seas_prod_t) + (1-lev_sms) * levels[t-1] 163 | levels += [newlev] 164 | 165 | if len(self.seasonality)==1: 166 | newseason1 = seas_sms1 * (ys[t] / newlev) + (1-seas_sms1) * seasonalities1[t] 167 | seasonalities1 += [newseason1] 168 | 169 | if len(self.seasonality)==2: 170 | newseason1 = seas_sms1 * (ys[t] / (newlev * seasonalities2[t])) + \ 171 | (1-seas_sms1) * seasonalities1[t] 172 | seasonalities1 += [newseason1] 173 | newseason2 = seas_sms2 * (ys[t] / (newlev * seasonalities1[t])) + \ 174 | (1-seas_sms2) * seasonalities2[t] 175 | seasonalities2 += [newseason2] 176 | 177 | levels = torch.stack(levels).transpose(1,0) 178 | 179 | #seasonalities = torch.jit.annotate(List[Tensor], []) 180 | seasonalities = [] 181 | 182 | if len(self.seasonality)>0: 183 | seasonalities += [torch.stack(seasonalities1).transpose(1,0)] 184 | 185 | if len(self.seasonality)==2: 186 | seasonalities += [torch.stack(seasonalities2).transpose(1,0)] 187 | 188 | return levels, seasonalities 189 | 190 | def normalize(self, y, level, seasonalities, start, end): 191 | # Deseasonalization and normalization 192 | y_n = y / level 193 | for s in range(len(self.seasonality)): 194 | y_n /= seasonalities[s][:, start:end] 195 | y_n = torch.log(y_n) 196 | return y_n 197 | 198 | def predict(self, trend, levels, seasonalities): 199 | output_size = self.mc.output_size 200 | seasonality = self.mc.seasonality 201 | n_time = levels.shape[1] 202 | 203 | # Denormalize 204 | trend = torch.exp(trend) 205 | 206 | # Completion of seasonalities if prediction horizon is larger than seasonality 207 | # Naive2 like prediction, to avoid recursive forecasting 208 | for s in range(len(seasonality)): 209 | if output_size > seasonality[s]: 210 | repetitions = int(np.ceil(output_size/seasonality[s]))-1 211 | last_season = seasonalities[s][:, -seasonality[s]:] 212 | extra_seasonality = last_season.repeat((1, repetitions)) 213 | seasonalities[s] = torch.cat((seasonalities[s], extra_seasonality), 1) 214 | 215 | # Deseasonalization and normalization (inverse) 216 | y_hat = trend * levels[:,[n_time-1]] 217 | for s in range(len(seasonality)): 218 | y_hat *= seasonalities[s][:, n_time:(n_time+output_size)] 219 | 220 | return y_hat 221 | 222 | class _RNN(nn.Module): 223 | def __init__(self, mc): 224 | super(_RNN, self).__init__() 225 | self.mc = mc 226 | self.layers = len(mc.dilations) 227 | 228 | layers = [] 229 | for grp_num in range(len(mc.dilations)): 230 | if grp_num == 0: 231 | input_size = mc.input_size + mc.exogenous_size 232 | else: 233 | input_size = mc.state_hsize 234 | layer = DRNN(input_size, 235 | mc.state_hsize, 236 | n_layers=len(mc.dilations[grp_num]), 237 | dilations=mc.dilations[grp_num], 238 | cell_type=mc.cell_type) 239 | layers.append(layer) 240 | 241 | self.rnn_stack = nn.Sequential(*layers) 242 | 243 | if self.mc.add_nl_layer: 244 | self.MLPW = nn.Linear(mc.state_hsize, mc.state_hsize) 245 | 246 | self.adapterW = nn.Linear(mc.state_hsize, mc.output_size) 247 | 248 | def forward(self, input_data): 249 | for layer_num in range(len(self.rnn_stack)): 250 | residual = input_data 251 | output, _ = self.rnn_stack[layer_num](input_data) 252 | if layer_num > 0: 253 | output += residual 254 | input_data = 
output 255 | 256 | if self.mc.add_nl_layer: 257 | input_data = self.MLPW(input_data) 258 | input_data = torch.tanh(input_data) 259 | 260 | input_data = self.adapterW(input_data) 261 | return input_data 262 | 263 | 264 | class _ESRNN(nn.Module): 265 | def __init__(self, mc): 266 | super(_ESRNN, self).__init__() 267 | self.mc = mc 268 | self.es = _ESM(mc).to(self.mc.device) 269 | self.rnn = _RNN(mc).to(self.mc.device) 270 | 271 | def forward(self, ts_object): 272 | # ES Forward 273 | windows_y_hat, windows_y, levels, seasonalities = self.es(ts_object) 274 | 275 | # RNN Forward 276 | windows_y_hat = self.rnn(windows_y_hat) 277 | 278 | return windows_y, windows_y_hat, levels 279 | 280 | def predict(self, ts_object): 281 | # ES Forward 282 | windows_y_hat, _, levels, seasonalities = self.es(ts_object) 283 | 284 | # RNN Forward 285 | windows_y_hat = self.rnn(windows_y_hat) 286 | trend = windows_y_hat[-1,:,:] # Last observation prediction 287 | 288 | y_hat = self.es.predict(trend, levels, seasonalities) 289 | return y_hat 290 | -------------------------------------------------------------------------------- /ESRNN/utils/DRNN.py: -------------------------------------------------------------------------------- 1 | # Dilated Recurrent Neural Networks. https://papers.nips.cc/paper/6613-dilated-recurrent-neural-networks.pdf 2 | # implementation from https://github.com/zalandoresearch/pytorch-dilated-rnn 3 | # Residual LSTM: Design of a Deep Recurrent Architecture for Distant Speech Recognition. https://arxiv.org/abs/1701.03360 4 | # A Dual-Stage Attention-Based recurrent neural network for time series prediction. https://arxiv.org/abs/1704.02971 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.autograd as autograd 9 | 10 | #import torch.jit as jit 11 | 12 | use_cuda = torch.cuda.is_available() 13 | 14 | 15 | class LSTMCell(nn.Module): #jit.ScriptModule 16 | def __init__(self, input_size, hidden_size, dropout=0.): 17 | super(LSTMCell, self).__init__() 18 | self.input_size = input_size 19 | self.hidden_size = hidden_size 20 | self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size)) 21 | self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size)) 22 | self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size)) 23 | self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size)) 24 | self.dropout = dropout 25 | 26 | #@jit.script_method 27 | def forward(self, input, hidden): 28 | # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] 29 | hx, cx = hidden[0].squeeze(0), hidden[1].squeeze(0) 30 | gates = (torch.matmul(input, self.weight_ih.t()) + self.bias_ih + 31 | torch.matmul(hx, self.weight_hh.t()) + self.bias_hh) 32 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 33 | 34 | ingate = torch.sigmoid(ingate) 35 | forgetgate = torch.sigmoid(forgetgate) 36 | cellgate = torch.tanh(cellgate) 37 | outgate = torch.sigmoid(outgate) 38 | 39 | cy = (forgetgate * cx) + (ingate * cellgate) 40 | hy = outgate * torch.tanh(cy) 41 | 42 | return hy, (hy, cy) 43 | 44 | 45 | class ResLSTMCell(nn.Module): 46 | def __init__(self, input_size, hidden_size, dropout=0.): 47 | super(ResLSTMCell, self).__init__() 48 | self.register_buffer('input_size', torch.Tensor([input_size])) 49 | self.register_buffer('hidden_size', torch.Tensor([hidden_size])) 50 | self.weight_ii = nn.Parameter(torch.randn(3 * hidden_size, input_size)) 51 | self.weight_ic = nn.Parameter(torch.randn(3 * hidden_size, hidden_size)) 52 | self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, 
hidden_size)) 53 | self.bias_ii = nn.Parameter(torch.randn(3 * hidden_size)) 54 | self.bias_ic = nn.Parameter(torch.randn(3 * hidden_size)) 55 | self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size)) 56 | self.weight_hh = nn.Parameter(torch.randn(1 * hidden_size, hidden_size)) 57 | self.bias_hh = nn.Parameter(torch.randn(1 * hidden_size)) 58 | self.weight_ir = nn.Parameter(torch.randn(hidden_size, input_size)) 59 | self.dropout = dropout 60 | 61 | #@jit.script_method 62 | def forward(self, input, hidden): 63 | # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] 64 | hx, cx = hidden[0].squeeze(0), hidden[1].squeeze(0) 65 | 66 | ifo_gates = (torch.matmul(input, self.weight_ii.t()) + self.bias_ii + 67 | torch.matmul(hx, self.weight_ih.t()) + self.bias_ih + 68 | torch.matmul(cx, self.weight_ic.t()) + self.bias_ic) 69 | ingate, forgetgate, outgate = ifo_gates.chunk(3, 1) 70 | 71 | cellgate = torch.matmul(hx, self.weight_hh.t()) + self.bias_hh 72 | 73 | ingate = torch.sigmoid(ingate) 74 | forgetgate = torch.sigmoid(forgetgate) 75 | cellgate = torch.tanh(cellgate) 76 | outgate = torch.sigmoid(outgate) 77 | 78 | cy = (forgetgate * cx) + (ingate * cellgate) 79 | ry = torch.tanh(cy) 80 | 81 | if self.input_size == self.hidden_size: 82 | hy = outgate * (ry + input) 83 | else: 84 | hy = outgate * (ry + torch.matmul(input, self.weight_ir.t())) 85 | return hy, (hy, cy) 86 | 87 | 88 | class ResLSTMLayer(nn.Module): 89 | def __init__(self, input_size, hidden_size, dropout=0.): 90 | super(ResLSTMLayer, self).__init__() 91 | self.input_size = input_size 92 | self.hidden_size = hidden_size 93 | self.cell = ResLSTMCell(input_size, hidden_size, dropout=0.) 94 | 95 | #@jit.script_method 96 | def forward(self, input, hidden): 97 | # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] 98 | inputs = input.unbind(0) 99 | #outputs = torch.jit.annotate(List[Tensor], []) 100 | outputs = [] 101 | for i in range(len(inputs)): 102 | out, hidden = self.cell(inputs[i], hidden) 103 | outputs += [out] 104 | outputs = torch.stack(outputs) 105 | return outputs, hidden 106 | 107 | 108 | class AttentiveLSTMLayer(nn.Module): 109 | def __init__(self, input_size, hidden_size, dropout=0.0): 110 | super(AttentiveLSTMLayer, self).__init__() 111 | self.input_size = input_size 112 | self.hidden_size = hidden_size 113 | attention_hsize = hidden_size 114 | self.attention_hsize = attention_hsize 115 | 116 | self.cell = LSTMCell(input_size, hidden_size) 117 | self.attn_layer = nn.Sequential(nn.Linear(2 * hidden_size + input_size, attention_hsize), 118 | nn.Tanh(), 119 | nn.Linear(attention_hsize, 1)) 120 | self.softmax = nn.Softmax(dim=0) 121 | self.dropout = dropout 122 | 123 | #@jit.script_method 124 | def forward(self, input, hidden): 125 | inputs = input.unbind(0) 126 | #outputs = torch.jit.annotate(List[Tensor], []) 127 | outputs = [] 128 | 129 | for t in range(len(input)): 130 | # attention on windows 131 | hx, cx = hidden[0].squeeze(0), hidden[1].squeeze(0) 132 | hx_rep = hx.repeat(len(inputs), 1, 1) 133 | cx_rep = cx.repeat(len(inputs), 1, 1) 134 | x = torch.cat((input, hx_rep, cx_rep), dim=-1) 135 | l = self.attn_layer(x) 136 | beta = self.softmax(l) 137 | context = torch.bmm(beta.permute(1, 2, 0), 138 | input.permute(1, 0, 2)).squeeze(1) 139 | out, hidden = self.cell(context, hidden) 140 | outputs += [out] 141 | outputs = torch.stack(outputs) 142 | return outputs, hidden 143 | 144 | 145 | class DRNN(nn.Module): 146 | 147 | def __init__(self, n_input, n_hidden, n_layers, 
dilations, dropout=0, cell_type='GRU', batch_first=False): 148 | 149 | super(DRNN, self).__init__() 150 | 151 | self.dilations = dilations 152 | self.cell_type = cell_type 153 | self.batch_first = batch_first 154 | 155 | layers = [] 156 | if self.cell_type == "GRU": 157 | cell = nn.GRU 158 | elif self.cell_type == "RNN": 159 | cell = nn.RNN 160 | elif self.cell_type == "LSTM": 161 | cell = nn.LSTM 162 | elif self.cell_type == "ResLSTM": 163 | cell = ResLSTMLayer 164 | elif self.cell_type == "AttentiveLSTM": 165 | cell = AttentiveLSTMLayer 166 | else: 167 | raise NotImplementedError 168 | 169 | for i in range(n_layers): 170 | if i == 0: 171 | c = cell(n_input, n_hidden, dropout=dropout) 172 | else: 173 | c = cell(n_hidden, n_hidden, dropout=dropout) 174 | layers.append(c) 175 | self.cells = nn.Sequential(*layers) 176 | 177 | def forward(self, inputs, hidden=None): 178 | if self.batch_first: 179 | inputs = inputs.transpose(0, 1) 180 | outputs = [] 181 | for i, (cell, dilation) in enumerate(zip(self.cells, self.dilations)): 182 | if hidden is None: 183 | inputs, _ = self.drnn_layer(cell, inputs, dilation) 184 | else: 185 | inputs, hidden[i] = self.drnn_layer(cell, inputs, dilation, hidden[i]) 186 | 187 | outputs.append(inputs[-dilation:]) 188 | 189 | if self.batch_first: 190 | inputs = inputs.transpose(0, 1) 191 | return inputs, outputs 192 | 193 | def drnn_layer(self, cell, inputs, rate, hidden=None): 194 | 195 | n_steps = len(inputs) 196 | batch_size = inputs[0].size(0) 197 | hidden_size = cell.hidden_size 198 | 199 | inputs, dilated_steps = self._pad_inputs(inputs, n_steps, rate) 200 | dilated_inputs = self._prepare_inputs(inputs, rate) 201 | 202 | if hidden is None: 203 | dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size) 204 | else: 205 | hidden = self._prepare_inputs(hidden, rate) 206 | dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size, 207 | hidden=hidden) 208 | 209 | splitted_outputs = self._split_outputs(dilated_outputs, rate) 210 | outputs = self._unpad_outputs(splitted_outputs, n_steps) 211 | 212 | return outputs, hidden 213 | 214 | def _apply_cell(self, dilated_inputs, cell, batch_size, rate, hidden_size, hidden=None): 215 | if hidden is None: 216 | if self.cell_type == 'LSTM' or self.cell_type == 'ResLSTM' or self.cell_type == 'AttentiveLSTM': 217 | c, m = self.init_hidden(batch_size * rate, hidden_size) 218 | hidden = (c.unsqueeze(0), m.unsqueeze(0)) 219 | else: 220 | hidden = self.init_hidden(batch_size * rate, hidden_size).unsqueeze(0) 221 | 222 | dilated_outputs, hidden = cell(dilated_inputs, hidden) # compatibility hack 223 | 224 | return dilated_outputs, hidden 225 | 226 | def _unpad_outputs(self, splitted_outputs, n_steps): 227 | return splitted_outputs[:n_steps] 228 | 229 | def _split_outputs(self, dilated_outputs, rate): 230 | batchsize = dilated_outputs.size(1) // rate 231 | 232 | blocks = [dilated_outputs[:, i * batchsize: (i + 1) * batchsize, :] for i in range(rate)] 233 | 234 | interleaved = torch.stack((blocks)).transpose(1, 0).contiguous() 235 | interleaved = interleaved.view(dilated_outputs.size(0) * rate, 236 | batchsize, 237 | dilated_outputs.size(2)) 238 | return interleaved 239 | 240 | def _pad_inputs(self, inputs, n_steps, rate): 241 | iseven = (n_steps % rate) == 0 242 | 243 | if not iseven: 244 | dilated_steps = n_steps // rate + 1 245 | 246 | zeros_ = torch.zeros(dilated_steps * rate - inputs.size(0), 247 | inputs.size(1), 248 | inputs.size(2)) 249 | if use_cuda: 250 
| zeros_ = zeros_.cuda() 251 | 252 | inputs = torch.cat((inputs, autograd.Variable(zeros_))) 253 | else: 254 | dilated_steps = n_steps // rate 255 | 256 | return inputs, dilated_steps 257 | 258 | def _prepare_inputs(self, inputs, rate): 259 | dilated_inputs = torch.cat([inputs[j::rate, :, :] for j in range(rate)], 1) 260 | return dilated_inputs 261 | 262 | def init_hidden(self, batch_size, hidden_dim): 263 | hidden = autograd.Variable(torch.zeros(batch_size, hidden_dim)) 264 | if use_cuda: 265 | hidden = hidden.cuda() 266 | if self.cell_type == "LSTM" or self.cell_type == 'ResLSTM' or self.cell_type == 'AttentiveLSTM': 267 | memory = autograd.Variable(torch.zeros(batch_size, hidden_dim)) 268 | if use_cuda: 269 | memory = memory.cuda() 270 | return hidden, memory 271 | else: 272 | return hidden 273 | 274 | 275 | if __name__ == '__main__': 276 | n_inp = 4 277 | n_hidden = 4 278 | n_layers = 2 279 | batch_size = 3 280 | n_windows = 2 281 | cell_type = 'ResLSTM' 282 | 283 | model = DRNN(n_inp, n_hidden, n_layers=n_layers, cell_type=cell_type, dilations=[1,2]) 284 | 285 | test_x1 = torch.autograd.Variable(torch.randn(n_windows, batch_size, n_inp)) 286 | test_x2 = torch.autograd.Variable(torch.randn(n_windows, batch_size, n_inp)) 287 | 288 | out, hidden = model(test_x1) 289 | -------------------------------------------------------------------------------- /ESRNN/utils_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import seed 3 | seed(1) 4 | 5 | import pandas as pd 6 | from math import sqrt 7 | 8 | 9 | ######################## 10 | # UTILITY MODELS 11 | ######################## 12 | 13 | def detrend(insample_data): 14 | """ 15 | Calculates a & b parameters of LRL 16 | :param insample_data: 17 | :return: 18 | """ 19 | x = np.arange(len(insample_data)) 20 | a, b = np.polyfit(x, insample_data, 1) 21 | return a, b 22 | 23 | def deseasonalize(original_ts, ppy): 24 | """ 25 | Calculates and returns seasonal indices 26 | :param original_ts: original data 27 | :param ppy: periods per year 28 | :return: 29 | """ 30 | """ 31 | # === get in-sample data 32 | original_ts = original_ts[:-out_of_sample] 33 | """ 34 | if seasonality_test(original_ts, ppy): 35 | # ==== get moving averages 36 | ma_ts = moving_averages(original_ts, ppy) 37 | 38 | # ==== get seasonality indices 39 | le_ts = original_ts * 100 / ma_ts 40 | le_ts = np.hstack((le_ts, np.full((ppy - (len(le_ts) % ppy)), np.nan))) 41 | le_ts = np.reshape(le_ts, (-1, ppy)) 42 | si = np.nanmean(le_ts, 0) 43 | norm = np.sum(si) / (ppy * 100) 44 | si = si / norm 45 | else: 46 | si = np.ones(ppy) 47 | 48 | return si 49 | 50 | def moving_averages(ts_init, window): 51 | """ 52 | Calculates the moving averages for a given TS 53 | :param ts_init: the original time series 54 | :param window: window length 55 | :return: moving averages ts 56 | """ 57 | """ 58 | As noted by Professor Isidro Lloret Galiana: 59 | line 82: 60 | if len(ts_init) % 2 == 0: 61 | 62 | should be changed to 63 | if window % 2 == 0: 64 | 65 | This change has a minor (less then 0.05%) impact on the calculations of the seasonal indices 66 | In order for the results to be fully replicable this change is not incorporated into the code below 67 | """ 68 | ts_init = pd.Series(ts_init) 69 | 70 | if len(ts_init) % 2 == 0: 71 | ts_ma = ts_init.rolling(window, center=True).mean() 72 | ts_ma = ts_ma.rolling(2, center=True).mean() 73 | ts_ma = np.roll(ts_ma, -1) 74 | else: 75 | ts_ma = ts_init.rolling(window, 
center=True).mean() 76 | 77 | return ts_ma 78 | 79 | def seasonality_test(original_ts, ppy): 80 | """ 81 | Seasonality test 82 | :param original_ts: time series 83 | :param ppy: periods per year 84 | :return: boolean value: whether the TS is seasonal 85 | """ 86 | s = acf(original_ts, 1) 87 | for i in range(2, ppy): 88 | s = s + (acf(original_ts, i) ** 2) 89 | 90 | limit = 1.645 * (sqrt((1 + 2 * s) / len(original_ts))) 91 | 92 | return (abs(acf(original_ts, ppy))) > limit 93 | 94 | def acf(data, k): 95 | """ 96 | Autocorrelation function 97 | :param data: time series 98 | :param k: lag 99 | :return: 100 | """ 101 | m = np.mean(data) 102 | s1 = 0 103 | for i in range(k, len(data)): 104 | s1 = s1 + ((data[i] - m) * (data[i - k] - m)) 105 | 106 | s2 = 0 107 | for i in range(0, len(data)): 108 | s2 = s2 + ((data[i] - m) ** 2) 109 | 110 | return float(s1 / s2) 111 | 112 | class Naive: 113 | """ 114 | Naive model. 115 | This benchmark model produces a forecast that is equal to 116 | the last observed value for a given time series. 117 | """ 118 | def __init__(self): 119 | pass 120 | 121 | def fit(self, ts_init): 122 | """ 123 | ts_init: the original time series 124 | ts_naive: last observations of time series 125 | """ 126 | self.ts_naive = [ts_init[-1]] 127 | return self 128 | 129 | def predict(self, h): 130 | return np.array(self.ts_naive * h) 131 | 132 | class SeasonalNaive: 133 | """ 134 | Seasonal Naive model. 135 | This benchmark model produces a forecast that is equal to 136 | the last observed value of the same season for a given time 137 | series. 138 | """ 139 | def __init__(self): 140 | pass 141 | 142 | def fit(self, ts_init, seasonality): 143 | """ 144 | ts_init: the original time series 145 | frcy: frequency of the time series 146 | ts_naive: last observations of time series 147 | """ 148 | self.ts_seasonal_naive = ts_init[-seasonality:] 149 | return self 150 | 151 | def predict(self, h): 152 | repetitions = int(np.ceil(h/len(self.ts_seasonal_naive))) 153 | y_hat = np.tile(self.ts_seasonal_naive, reps=repetitions)[:h] 154 | return y_hat 155 | 156 | class Naive2: 157 | """ 158 | Naive2 model. 159 | Popular benchmark model for time series forecasting that automatically adapts 160 | to the potential seasonality of a series based on an autocorrelation test. 161 | If the series is seasonal the model composes the predictions of Naive and SeasonalNaive, 162 | else the model predicts on the simple Naive. 163 | """ 164 | def __init__(self, seasonality): 165 | self.seasonality = seasonality 166 | 167 | def fit(self, ts_init): 168 | seasonality_in = deseasonalize(ts_init, ppy=self.seasonality) 169 | windows = int(np.ceil(len(ts_init) / self.seasonality)) 170 | 171 | self.ts_init = ts_init 172 | self.s_hat = np.tile(seasonality_in, reps=windows)[:len(ts_init)] 173 | self.ts_des = ts_init / self.s_hat 174 | 175 | return self 176 | 177 | def predict(self, h): 178 | s_hat = SeasonalNaive().fit(self.s_hat, 179 | seasonality=self.seasonality).predict(h) 180 | r_hat = Naive().fit(self.ts_des).predict(h) 181 | y_hat = s_hat * r_hat 182 | return y_hat 183 | 184 | ######################## 185 | # METRICS 186 | ######################## 187 | 188 | def mse(y, y_hat): 189 | """ 190 | Calculates Mean Squared Error. 
191 | 192 | Parameters 193 | ---------- 194 | y: numpy array 195 | actual test values 196 | y_hat: numpy array 197 | predicted values 198 | 199 | Returns 200 | ------- 201 | mse: float 202 | mean squared error 203 | """ 204 | y = np.reshape(y, (-1,)) 205 | y_hat = np.reshape(y_hat, (-1,)) 206 | mse = np.mean(np.square(y - y_hat)).item() 207 | return mse 208 | 209 | def mape(y, y_hat): 210 | """ 211 | Calculates Mean Absolute Percentage Error. 212 | 213 | Parameters 214 | ---------- 215 | y: numpy array 216 | actual test values 217 | y_hat: numpy array 218 | predicted values 219 | 220 | Returns 221 | ------- 222 | mape: float 223 | mean absolute percentage error 224 | """ 225 | y = np.reshape(y, (-1,)) 226 | y_hat = np.reshape(y_hat, (-1,)) 227 | mape = np.mean(np.abs(y - y_hat) / np.abs(y)) 228 | return mape 229 | 230 | def smape(y, y_hat): 231 | """ 232 | Calculates Symmetric Mean Absolute Percentage Error. 233 | 234 | Parameters 235 | ---------- 236 | y: numpy array 237 | actual test values 238 | y_hat: numpy array 239 | predicted values 240 | 241 | Returns 242 | ------- 243 | smape: float 244 | symmetric mean absolute percentage error 245 | """ 246 | y = np.reshape(y, (-1,)) 247 | y_hat = np.reshape(y_hat, (-1,)) 248 | smape = np.mean(2.0 * np.abs(y - y_hat) / (np.abs(y) + np.abs(y_hat))) 249 | return smape 250 | 251 | def mase(y, y_hat, y_train, seasonality): 252 | """ 253 | Calculates Mean Absolute Scaled Error. 254 | 255 | Parameters 256 | ---------- 257 | y: numpy array 258 | actual test values 259 | y_hat: numpy array 260 | predicted values 261 | y_train: numpy array 262 | actual train values for Naive1 predictions 263 | seasonality: int 264 | main frequency of the time series 265 | Quarterly 4, Daily 7, Monthly 12 266 | 267 | Returns 268 | ------- 269 | mase: float 270 | mean absolute scaled error 271 | """ 272 | y_hat_naive = [] 273 | for i in range(seasonality, len(y_train)): 274 | y_hat_naive.append(y_train[(i - seasonality)]) 275 | 276 | masep = np.mean(abs(y_train[seasonality:] - y_hat_naive)) 277 | mase = np.mean(abs(y - y_hat)) / masep 278 | return mase 279 | 280 | ######################## 281 | # PANEL EVALUATION 282 | ######################## 283 | 284 | def evaluate_panel(y_panel, y_hat_panel, metric, 285 | y_insample=None, seasonality=None): 286 | """ 287 | Calculates metric for y_panel and y_hat_panel 288 | y_panel: pandas df 289 | panel with columns unique_id, ds, y 290 | y_naive2_panel: pandas df 291 | panel with columns unique_id, ds, y_hat 292 | y_insample: pandas df 293 | panel with columns unique_id, ds, y (train) 294 | this is used in the MASE 295 | seasonality: int 296 | main frequency of the time series 297 | Quarterly 4, Daily 7, Monthly 12 298 | return: list of metric evaluations 299 | """ 300 | metric_name = metric.__code__.co_name 301 | 302 | y_panel = y_panel.sort_values(['unique_id', 'ds']) 303 | y_hat_panel = y_hat_panel.sort_values(['unique_id', 'ds']) 304 | if y_insample is not None: 305 | y_insample = y_insample.sort_values(['unique_id', 'ds']) 306 | 307 | assert len(y_panel)==len(y_hat_panel) 308 | assert all(y_panel.unique_id.unique() == y_hat_panel.unique_id.unique()), "not same u_ids" 309 | 310 | evaluation = [] 311 | for u_id in y_panel.unique_id.unique(): 312 | top_row = np.asscalar(y_panel['unique_id'].searchsorted(u_id, 'left')) 313 | bottom_row = np.asscalar(y_panel['unique_id'].searchsorted(u_id, 'right')) 314 | y_id = y_panel[top_row:bottom_row].y.to_numpy() 315 | 316 | top_row = np.asscalar(y_hat_panel['unique_id'].searchsorted(u_id, 
'left'))
317 | bottom_row = np.asscalar(y_hat_panel['unique_id'].searchsorted(u_id, 'right'))
318 | y_hat_id = y_hat_panel[top_row:bottom_row].y_hat.to_numpy()
319 | assert len(y_id)==len(y_hat_id)
320 |
321 | if metric_name == 'mase':
322 | assert (y_insample is not None) and (seasonality is not None)
323 | top_row = np.asscalar(y_insample['unique_id'].searchsorted(u_id, 'left'))
324 | bottom_row = np.asscalar(y_insample['unique_id'].searchsorted(u_id, 'right'))
325 | y_insample_id = y_insample[top_row:bottom_row].y.to_numpy()
326 | evaluation_id = metric(y_id, y_hat_id, y_insample_id, seasonality)
327 | else:
328 | evaluation_id = metric(y_id, y_hat_id)
329 | evaluation.append(evaluation_id)
330 | return evaluation
331 |
332 | def owa(y_panel, y_hat_panel, y_naive2_panel, y_insample, seasonality):
333 | """
334 | Calculates MASE and sMAPE for the Naive2 benchmark and the current model,
335 | then calculates the Overall Weighted Average.
336 | y_panel: pandas df
337 | panel with columns unique_id, ds, y
338 | y_hat_panel: pandas df
339 | panel with columns unique_id, ds, y_hat
340 | y_naive2_panel: pandas df
341 | panel with columns unique_id, ds, y_hat
342 | y_insample: pandas df
343 | panel with columns unique_id, ds, y (train)
344 | this is used in the MASE
345 | seasonality: int
346 | main frequency of the time series
347 | Quarterly 4, Daily 7, Monthly 12
348 | return: OWA
349 | """
350 | total_mase = evaluate_panel(y_panel, y_hat_panel, mase,
351 | y_insample, seasonality)
352 | total_mase_naive2 = evaluate_panel(y_panel, y_naive2_panel, mase,
353 | y_insample, seasonality)
354 | total_smape = evaluate_panel(y_panel, y_hat_panel, smape)
355 | total_smape_naive2 = evaluate_panel(y_panel, y_naive2_panel, smape)
356 |
357 | assert len(total_mase) == len(total_mase_naive2)
358 | assert len(total_smape) == len(total_smape_naive2)
359 | assert len(total_mase) == len(total_smape)
360 |
361 | naive2_mase = np.mean(total_mase_naive2)
362 | naive2_smape = np.mean(total_smape_naive2) * 100
363 |
364 | model_mase = np.mean(total_mase)
365 | model_smape = np.mean(total_smape) * 100
366 |
367 | model_owa = ((model_mase/naive2_mase) + (model_smape/naive2_smape))/2
368 | return model_owa, model_mase, model_smape
369 |
370 | def evaluate_prediction_owa(y_hat_df, y_train_df, X_test_df, y_test_df,
371 | naive2_seasonality):
372 | """
373 | y_hat_df: pandas df
374 | panel with columns unique_id, ds, y_hat
375 | y_train_df: pandas df
376 | panel with columns unique_id, ds, y
377 | X_test_df: pandas df
378 | panel with columns unique_id, ds, x
379 | y_test_df: pandas df
380 | panel with columns unique_id, ds, y, y_hat_naive2
381 | naive2_seasonality: int
382 | seasonality for the Naive2 predictions (needed for owa)
383 | return: tuple
384 | (model_owa, model_mase, model_smape)
385 | """
386 | y_panel = y_test_df.filter(['unique_id', 'ds', 'y'])
387 | y_naive2_panel = y_test_df.filter(['unique_id', 'ds', 'y_hat_naive2'])
388 | y_naive2_panel.rename(columns={'y_hat_naive2': 'y_hat'}, inplace=True)
389 | y_hat_panel = y_hat_df
390 | y_insample = y_train_df.filter(['unique_id', 'ds', 'y'])
391 |
392 | model_owa, model_mase, model_smape = owa(y_panel, y_hat_panel,
393 | y_naive2_panel, y_insample,
394 | seasonality=naive2_seasonality)
395 |
396 | print(15*'=', ' Model evaluation ', 14*'=')
397 | print('OWA: {} '.format(np.round(model_owa, 3)))
398 | print('SMAPE: {} '.format(np.round(model_smape, 3)))
399 | print('MASE: {} '.format(np.round(model_mase, 3)))
400 | return model_owa, model_mase, model_smape
401 |
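A minimal sketch of how the panel utilities above compose, pairing evaluate_panel with the smape metric on a toy two-series panel (the ids, dates, and values are illustrative assumptions, not repo data):

import pandas as pd
from ESRNN.utils_evaluation import evaluate_panel, smape

# Toy long-format panels: two series, two test dates each.
y_panel = pd.DataFrame({
    'unique_id': ['s1', 's1', 's2', 's2'],
    'ds': pd.to_datetime(['2020-01-01', '2020-01-02'] * 2),
    'y': [10.0, 12.0, 20.0, 18.0]})

# The predictions panel mirrors y_panel row-for-row, with a y_hat column.
y_hat_panel = y_panel.drop(columns='y').assign(y_hat=[11.0, 12.5, 19.0, 18.5])

# evaluate_panel returns one sMAPE value per unique_id; owa() aggregates these
# (together with MASE) against the Naive2 benchmark into the M4 overall
# weighted average.
print(evaluate_panel(y_panel, y_hat_panel, metric=smape))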
--------------------------------------------------------------------------------
/ESRNN/ESRNNensemble.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import torch
9 |
10 | from pathlib import Path
11 |
12 | from ESRNN.utils.config import ModelConfig
13 | from ESRNN.utils.losses import DisaggregatedPinballLoss
14 | from ESRNN.utils.data import Iterator
15 |
16 | from ESRNN.ESRNN import ESRNN
17 |
18 | from ESRNN.utils_evaluation import owa
19 |
20 | class ESRNNensemble(object):
21 | """ Exponential Smoothing Recurrent Neural Network Ensemble
22 |
23 | Pytorch Implementation of the M4 time series forecasting competition winner.
24 | Proposed by Smyl. The model ensembles multiple ESRNNs that use a hybrid approach
25 | of Machine Learning and statistical methods by combining recurrent neural networks
26 | to model a common trend with shared parameters across series, and multiplicative
27 | Holt-Winters exponential smoothing.
28 |
29 | Parameters
30 | ----------
31 | n_models: int
32 | the number of ESRNNs in the ensemble.
33 | n_top: int
34 | the number of ESRNNs from the n_models pool, to which a
35 | particular time series gets assigned during the fitting procedure.
36 | max_epochs: int
37 | maximum number of complete passes to train data during fit
38 | freq_of_test: int
39 | period for the diagnostic evaluation of the model.
40 | learning_rate: float
41 | size of the stochastic gradient descent steps
42 | lr_scheduler_step_size: int
43 | this step_size is the period for each learning rate decay
44 | per_series_lr_multip: float
45 | multiplier for per-series parameters smoothing and initial
46 | seasonalities learning rate (default 1.0)
47 | gradient_eps: float
48 | term added to the Adam optimizer denominator to improve
49 | numerical stability (default: 1e-8)
50 | gradient_clipping_threshold: float
51 | max norm of gradient vector, with all parameters treated
52 | as a single vector
53 | rnn_weight_decay: float
54 | parameter to control classic L2/Tikhonov regularization
55 | of the rnn parameters
56 | noise_std: float
57 | standard deviation of white noise added to input during
58 | fit to keep the model from memorizing the train data
59 | level_variability_penalty: float
60 | this parameter controls the strength of the penalization
61 | of the wiggliness of the level vector; it induces smoothness
62 | in the output
63 | testing_percentile: float
64 | This value is only for diagnostic evaluation.
65 | In the case of percentile predictions this parameter controls
66 | the value predicted; when forecasting a point value,
67 | the forecast is the median, so percentile=50.
68 | training_percentile: float
69 | To reduce the model's tendency to overestimate, the
70 | training_percentile can be set to fit a smaller value
71 | through the Pinball Loss.
72 | batch_size: int
73 | number of training examples for the stochastic gradient steps
74 | seasonality: int list
75 | list of seasonalities of the time series
76 | Hourly [24, 168], Daily [7], Weekly [52], Monthly [12],
77 | Quarterly [4], Yearly [].
78 | input_size: int
79 | input size of the recurrent neural network, usually a
80 | multiple of seasonality
81 | output_size: int
82 | output_size or forecast horizon of the recurrent neural
83 | network, usually a multiple of seasonality
84 | random_seed: int
85 | random_seed for pseudo random pytorch initializer and
86 | numpy random generator
87 | exogenous_size: int
88 | size of the one-hot encoded categorical variable, invariant
89 | per time series of the panel
90 | min_inp_seq_length: int
91 | description
92 | max_periods: int
93 | parameter to truncate longer series to their last max_periods
94 | observations, e.g. 40 years at most
95 | cell_type: str
96 | Type of RNN cell; available: GRU, LSTM, RNN, ResLSTM, AttentiveLSTM.
97 | state_hsize: int
98 | dimension of hidden state of the recurrent neural network
99 | dilations: int list
100 | each list represents one chunk of Dilated LSTMs, connected in
101 | standard ResNet fashion
102 | add_nl_layer: bool
103 | whether to insert a tanh() layer between the RNN stack and the
104 | linear adaptor (output) layers
105 | device: str
106 | pytorch device either 'cpu' or 'cuda'
107 | Notes
108 | -----
109 | **References:**
110 | `M4 Competition Conclusions
111 | `__
112 | `Original Dynet Implementation of ESRNN
113 | `__
114 | """
115 | def __init__(self, n_models=1, n_top=1, max_epochs=15, batch_size=1, batch_size_test=128,
116 | freq_of_test=-1, learning_rate=1e-3, lr_scheduler_step_size=9, lr_decay=0.9,
117 | per_series_lr_multip=1.0, gradient_eps=1e-8, gradient_clipping_threshold=20,
118 | rnn_weight_decay=0, noise_std=0.001, level_variability_penalty=80,
119 | testing_percentile=50, training_percentile=50, ensemble=False, cell_type='LSTM',
120 | state_hsize=40, dilations=[[1, 2], [4, 8]],
121 | add_nl_layer=False, seasonality=[4], input_size=4, output_size=8,
122 | frequency='D', max_periods=20, random_seed=1,
123 | device='cuda', root_dir='./'):
124 | super(ESRNNensemble, self).__init__()
125 |
126 | self.n_models = n_models
127 | self.n_top = n_top
128 | assert n_models>=2, "Number of models for ensemble should be greater than 1"
129 | assert n_top<=n_models, "Number of top models should be smaller than models to ensemble"
130 | self.big_float = 1e6
131 | self.mc = ModelConfig(max_epochs=max_epochs, batch_size=batch_size, batch_size_test=batch_size_test,
132 | freq_of_test=freq_of_test, learning_rate=learning_rate,
133 | lr_scheduler_step_size=lr_scheduler_step_size, lr_decay=lr_decay,
134 | per_series_lr_multip=per_series_lr_multip,
135 | gradient_eps=gradient_eps, gradient_clipping_threshold=gradient_clipping_threshold,
136 | rnn_weight_decay=rnn_weight_decay, noise_std=noise_std,
137 | level_variability_penalty=level_variability_penalty,
138 | testing_percentile=testing_percentile, training_percentile=training_percentile,
139 | ensemble=ensemble, cell_type=cell_type,
140 | state_hsize=state_hsize, dilations=dilations, add_nl_layer=add_nl_layer,
141 | seasonality=seasonality, input_size=input_size, output_size=output_size,
142 | frequency=frequency, max_periods=max_periods, random_seed=random_seed,
143 | device=device, root_dir=root_dir)
144 | self._fitted = False
145 |
146 | def fit(self, X_df, y_df, X_test_df=None, y_test_df=None, shuffle=True):
147 | """
148 | Fit ESRNN ensemble.
149 |
150 | Parameters
151 | ----------
152 | X_df : pandas dataframe
153 | Train dataframe in long format with columns 'unique_id', 'ds'
154 | and 'x'.
155 | - 'unique_id' an identifier of each independent time series.
156 | - 'ds' is a datetime column
157 | - 'x' is a single exogenous variable
158 | y_df : pandas dataframe
159 | Train dataframe in long format with columns 'unique_id', 'ds' and 'y'.
160 | - 'unique_id' an identifier of each independent time series.
161 | - 'ds' is a datetime column
162 | - 'y' is the column with the target values
163 | X_test_df: pandas dataframe
164 | Optional test dataframe with columns 'unique_id', 'ds' and 'x'.
165 | If provided the fit procedure will evaluate the intermediate
166 | performance within training epochs.
167 | y_test_df: pandas dataframe
168 | Optional test dataframe with columns 'unique_id', 'ds', 'y' and a
169 | y_hat_benchmark column.
170 | If provided the fit procedure will evaluate the intermediate
171 | performance within training epochs.
172 | shuffle: boolean
173 | whether to shuffle the series before each training epoch of
174 | the per-model train procedure.
175 |
176 | Returns
177 | -------
178 | self : returns an instance of self.
179 | """
180 |
181 | # Transform long dfs to wide numpy
182 | assert type(X_df) == pd.core.frame.DataFrame
183 | assert type(y_df) == pd.core.frame.DataFrame
184 | assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
185 | assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
186 |
187 | # Storing dfs for OWA evaluation, initializing min_owa
188 | self.y_train_df = y_df
189 | self.X_test_df = X_test_df
190 | self.y_test_df = y_test_df
191 | self.min_owa = 4.0
192 | self.min_epoch = 0
193 |
194 | # Exogenous variables
195 | unique_categories = X_df['x'].unique()
196 | self.mc.category_to_idx = dict((word, index) for index, word in enumerate(unique_categories))
197 | self.mc.exogenous_size = len(unique_categories)
198 |
199 | self.unique_ids = X_df['unique_id'].unique()
200 | self.mc.n_series = len(self.unique_ids)
201 |
202 | # Set seeds
203 | self.shuffle = shuffle
204 | torch.manual_seed(self.mc.random_seed)
205 | np.random.seed(self.mc.random_seed)
206 |
207 | # Initial series random assignment to models
208 | self.series_models_map = np.zeros((self.mc.n_series, self.n_models))
209 | n_initial_models = int(np.ceil(self.n_models/2))
210 | for i in range(self.mc.n_series):
211 | id_models = np.random.choice(self.n_models, n_initial_models)
212 | self.series_models_map[i,id_models] = 1
213 |
214 | self.esrnn_ensemble = []
215 | for _ in range(self.n_models):
216 | esrnn = ESRNN(max_epochs=self.mc.max_epochs, batch_size=self.mc.batch_size, batch_size_test=self.mc.batch_size_test,
217 | freq_of_test=-1, learning_rate=self.mc.learning_rate,
218 | lr_scheduler_step_size=self.mc.lr_scheduler_step_size, lr_decay=self.mc.lr_decay,
219 | per_series_lr_multip=self.mc.per_series_lr_multip,
220 | gradient_eps=self.mc.gradient_eps, gradient_clipping_threshold=self.mc.gradient_clipping_threshold,
221 | rnn_weight_decay=self.mc.rnn_weight_decay, noise_std=self.mc.noise_std,
222 | level_variability_penalty=self.mc.level_variability_penalty,
223 | testing_percentile=self.mc.testing_percentile,
224 | training_percentile=self.mc.training_percentile, ensemble=self.mc.ensemble,
225 | cell_type=self.mc.cell_type,
226 | state_hsize=self.mc.state_hsize, dilations=self.mc.dilations, add_nl_layer=self.mc.add_nl_layer,
227 | seasonality=self.mc.seasonality, input_size=self.mc.input_size, output_size=self.mc.output_size,
228 | frequency=self.mc.frequency, max_periods=self.mc.max_periods, random_seed=self.mc.random_seed,
229 | device=self.mc.device, root_dir=self.mc.root_dir)
230 |
231 | # To instantiate _ESRNN object
within ESRNN class we need n_series 232 | esrnn.instantiate_esrnn(self.mc.exogenous_size, self.mc.n_series) 233 | esrnn._fitted = True 234 | self.esrnn_ensemble.append(esrnn) 235 | 236 | self.X, self.y = esrnn.long_to_wide(X_df, y_df) 237 | assert len(self.X)==len(self.y) 238 | assert self.X.shape[1]>=3 239 | 240 | # Train model 241 | self._fitted = True 242 | self.train() 243 | 244 | def train(self): 245 | """ 246 | Auxiliary function, pytorch train procedure for the ESRNN ensemble 247 | 248 | Parameters: 249 | ------- 250 | self: instance of self. 251 | 252 | Returns 253 | ------- 254 | self : returns an instance of self. 255 | """ 256 | 257 | # Initial performance matrix 258 | self.performance_matrix = np.ones((self.mc.n_series, self.n_models)) * self.big_float 259 | warm_start = False 260 | train_tau = self.mc.training_percentile/100 261 | criterion = DisaggregatedPinballLoss(train_tau) 262 | 263 | # Train epoch loop 264 | for epoch in range(self.mc.max_epochs): 265 | start = time.time() 266 | 267 | # Solve degenerate models 268 | for model_id in range(self.n_models): 269 | if np.sum(self.series_models_map[:,model_id])==0: 270 | print('Reassigning random series to model ', model_id) 271 | n_sample_series= int(self.mc.n_series/2) 272 | index_series = np.random.choice(self.mc.n_series, n_sample_series, replace=False) 273 | self.series_models_map[index_series, model_id] = 1 274 | 275 | # Model loop 276 | for model_id, esrnn in enumerate(self.esrnn_ensemble): 277 | # Train model with subset data 278 | dataloader = Iterator(mc = self.mc, X=self.X, y=self.y, 279 | weights=self.series_models_map[:, model_id]) 280 | esrnn.train(dataloader, max_epochs=1, warm_start=warm_start, shuffle=self.shuffle, verbose=False) 281 | 282 | # Compute model performance for each series 283 | dataloader = Iterator(mc=self.mc, X=self.X, y=self.y) 284 | per_series_evaluation = esrnn.per_series_evaluation(dataloader, criterion=criterion) 285 | self.performance_matrix[:, model_id] = per_series_evaluation 286 | 287 | # Reassign series to models 288 | self.series_models_map = np.zeros((self.mc.n_series, self.n_models)) 289 | top_models = np.argpartition(self.performance_matrix, self.n_top)[:, :self.n_top] 290 | for i in range(self.mc.n_series): 291 | self.series_models_map[i, top_models[i,:]] = 1 292 | 293 | warm_start = True 294 | 295 | print("========= Epoch {} finished =========".format(epoch)) 296 | print("Training time: {}".format(round(time.time()-start, 5))) 297 | self.train_loss = np.einsum('ij,ij->i',self.performance_matrix, self.series_models_map)/self.n_top 298 | self.train_loss = np.mean(self.train_loss) 299 | print("Training loss ({} prc): {:.5f}".format(self.mc.training_percentile, 300 | self.train_loss)) 301 | print('Models num series', np.sum(self.series_models_map, axis=0)) 302 | 303 | if (epoch % self.mc.freq_of_test == 0) and (self.mc.freq_of_test > 0): 304 | if self.y_test_df is not None: 305 | self.evaluate_model_prediction(self.y_train_df, self.X_test_df, 306 | self.y_test_df, epoch=epoch) 307 | print('Train finished! \n') 308 | 309 | def predict(self, X_df): 310 | """ 311 | Predict using the ESRNN ensemble. 312 | 313 | Parameters 314 | ---------- 315 | X_df : pandas dataframe 316 | Dataframe in LONG format with columns 'unique_id', 'ds' 317 | and 'x'. 318 | - 'unique_id' an identifier of each independent time series. 
319 | - 'ds' is a datetime column
320 | - 'x' is a single exogenous variable
321 |
322 | Returns
323 | -------
324 | Y_hat_panel : pandas dataframe
325 | Dataframe in LONG format with columns 'unique_id', 'ds'
326 | and 'y_hat'.
327 | - 'unique_id' an identifier of each independent time series.
328 | - 'ds' datetime column that matches the dates in X_df
329 | - 'y_hat' is the column with the predicted target values
330 | """
331 |
332 | assert type(X_df) == pd.core.frame.DataFrame
333 | assert 'unique_id' in X_df
334 | assert self._fitted, "Model not fitted yet"
335 |
336 | dataloader = Iterator(mc=self.mc, X=self.X, y=self.y)
337 |
338 | output_size = self.mc.output_size
339 | n_unique_id = len(dataloader.sort_key['unique_id'])
340 |
341 | ensemble_y_hat = np.zeros((self.n_models, n_unique_id, output_size))
342 |
343 | for model_id, esrnn in enumerate(self.esrnn_ensemble):
344 | esrnn.esrnn.eval()
345 |
346 | # Predict ALL series
347 | count = 0
348 | for j in range(dataloader.n_batches):
349 | batch = dataloader.get_batch()
350 | batch_size = batch.y.shape[0]
351 |
352 | y_hat = esrnn.esrnn.predict(batch)
353 |
354 | y_hat = y_hat.data.cpu().numpy()
355 |
356 | ensemble_y_hat[model_id, count:count+batch_size, :] = y_hat
357 | count += batch_size
358 |
359 | # Weighted average of prediction for n_top best models per series
360 | # (n_models x n_unique_id x output_size) (n_unique_id x n_models)
361 | y_hat = np.einsum('ijk,ji->jk', ensemble_y_hat, self.series_models_map) / self.n_top
362 | y_hat = y_hat.flatten()
363 |
364 | panel_unique_id = pd.Series(dataloader.sort_key['unique_id']).repeat(output_size)
365 | panel_last_ds = pd.Series(dataloader.X[:, 2]).repeat(output_size)
366 |
367 | panel_delta = list(range(1, output_size+1)) * n_unique_id
368 | panel_delta = pd.to_timedelta(panel_delta, unit=self.mc.frequency)
369 | panel_ds = panel_last_ds + panel_delta
370 |
371 | assert len(panel_ds) == len(y_hat) == len(panel_unique_id)
372 |
373 | Y_hat_panel_dict = {'unique_id': panel_unique_id,
374 | 'ds': panel_ds,
375 | 'y_hat': y_hat}
376 |
377 | Y_hat_panel = pd.DataFrame.from_dict(Y_hat_panel_dict)
378 |
379 | if 'ds' in X_df:
380 | Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id', 'ds'], how='left')
381 | else:
382 | Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id'], how='left')
383 |
384 | return Y_hat_panel
385 |
386 | def evaluate_model_prediction(self, y_train_df, X_test_df, y_test_df, epoch=None):
387 | """
388 | Evaluate ESRNN model against benchmark in y_test_df
389 |
390 | Parameters
391 | ----------
392 | y_train_df: pandas dataframe
393 | panel with columns 'unique_id', 'ds', 'y'
394 | X_test_df: pandas dataframe
395 | panel with columns 'unique_id', 'ds', 'x'
396 | y_test_df: pandas dataframe
397 | panel with columns 'unique_id', 'ds', 'y' and a column
398 | y_hat_naive2 identifying benchmark predictions
399 | epoch: int
400 | the epoch number, used to track early stopping results
401 |
402 | Returns
403 | -------
404 | model_owa : float
405 | relative improvement of model with respect to benchmark, measured with
406 | the M4 overall weighted average.
407 | smape: float
408 | relative improvement of model with respect to benchmark, measured with
409 | the symmetric mean absolute percentage error.
410 | mase: float
411 | relative improvement of model with respect to benchmark, measured with
412 | the M4 mean absolute scaled error.
413 | """ 414 | 415 | assert self._fitted, "Model not fitted yet" 416 | 417 | y_panel = y_test_df.filter(['unique_id', 'ds', 'y']) 418 | y_naive2_panel = y_test_df.filter(['unique_id', 'ds', 'y_hat_naive2']) 419 | y_naive2_panel.rename(columns={'y_hat_naive2': 'y_hat'}, inplace=True) 420 | y_hat_panel = self.predict(X_test_df) 421 | y_insample = y_train_df.filter(['unique_id', 'ds', 'y']) 422 | 423 | model_owa, model_mase, model_smape = owa(y_panel, y_hat_panel, 424 | y_naive2_panel, y_insample, 425 | seasonality=self.mc.naive_seasonality) 426 | 427 | if self.min_owa > model_owa: 428 | self.min_owa = model_owa 429 | if epoch is not None: 430 | self.min_epoch = epoch 431 | 432 | print('OWA: {} '.format(np.round(model_owa, 3))) 433 | print('SMAPE: {} '.format(np.round(model_smape, 3))) 434 | print('MASE: {} '.format(np.round(model_mase, 3))) 435 | 436 | return model_owa, model_mase, model_smape 437 | -------------------------------------------------------------------------------- /ESRNN/ESRNN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import StepLR 11 | 12 | from pathlib import Path 13 | from copy import deepcopy 14 | 15 | from ESRNN.utils.config import ModelConfig 16 | from ESRNN.utils.ESRNN import _ESRNN 17 | from ESRNN.utils.losses import SmylLoss, PinballLoss 18 | from ESRNN.utils.data import Iterator 19 | 20 | from ESRNN.utils_evaluation import owa 21 | 22 | 23 | class ESRNN(object): 24 | """ Exponential Smoothing Recurrent Neural Network 25 | 26 | Pytorch Implementation of the M4 time series forecasting competition winner. 27 | Proposed by Smyl. The model uses a hybrid approach of Machine Learning and 28 | statistical methods by combining recurrent neural networks to model a common 29 | trend with shared parameters across series, and multiplicative Holt-Winter 30 | exponential smoothing. 31 | 32 | Parameters 33 | ---------- 34 | max_epochs: int 35 | maximum number of complete passes to train data during fit 36 | freq_of_test: int 37 | period for the diagnostic evaluation of the model. 38 | learning_rate: float 39 | size of the stochastic gradient descent steps 40 | lr_scheduler_step_size: int 41 | this step_size is the period for each learning rate decay 42 | per_series_lr_multip: float 43 | multiplier for per-series parameters smoothing and initial 44 | seasonalities learning rate (default 1.0) 45 | gradient_eps: float 46 | term added to the Adam optimizer denominator to improve 47 | numerical stability (default: 1e-8) 48 | gradient_clipping_threshold: float 49 | max norm of gradient vector, with all parameters treated 50 | as a single vector 51 | rnn_weight_decay: float 52 | parameter to control classic L2/Tikhonov regularization 53 | of the rnn parameters 54 | noise_std: float 55 | standard deviation of white noise added to input during 56 | fit to avoid the model from memorizing the train data 57 | level_variability_penalty: float 58 | this parameter controls the strength of the penalization 59 | to the wigglines of the level vector, induces smoothness 60 | in the output 61 | testing_percentile: float 62 | This value is only for diagnostic evaluation. 63 | In case of percentile predictions this parameter controls 64 | for the value predicted, when forecasting point value, 65 | the forecast is the median, so percentile=50. 
66 | training_percentile: float
67 | To reduce the model's tendency to overestimate, the
68 | training_percentile can be set to fit a smaller value
69 | through the Pinball Loss.
70 | batch_size: int
71 | number of training examples for the stochastic gradient steps
72 | seasonality: int list
73 | list of seasonalities of the time series
74 | Hourly [24, 168], Daily [7], Weekly [52], Monthly [12],
75 | Quarterly [4], Yearly [].
76 | input_size: int
77 | input size of the recurrent neural network, usually a
78 | multiple of seasonality
79 | output_size: int
80 | output_size or forecast horizon of the recurrent neural
81 | network, usually a multiple of seasonality
82 | random_seed: int
83 | random_seed for pseudo random pytorch initializer and
84 | numpy random generator
85 | exogenous_size: int
86 | size of the one-hot encoded categorical variable, invariant
87 | per time series of the panel
88 | min_inp_seq_length: int
89 | description
90 | max_periods: int
91 | parameter to truncate longer series to their last max_periods
92 | observations, e.g. 40 years at most
93 | cell_type: str
94 | Type of RNN cell; available: GRU, LSTM, RNN, ResLSTM, AttentiveLSTM.
95 | state_hsize: int
96 | dimension of hidden state of the recurrent neural network
97 | dilations: int list
98 | each list represents one chunk of Dilated LSTMs, connected in
99 | standard ResNet fashion
100 | add_nl_layer: bool
101 | whether to insert a tanh() layer between the RNN stack and the
102 | linear adaptor (output) layers
103 | device: str
104 | pytorch device either 'cpu' or 'cuda'
105 | Notes
106 | -----
107 | **References:**
108 | `M4 Competition Conclusions
109 | `__
110 | `Original Dynet Implementation of ESRNN
111 | `__
112 | """
113 | def __init__(self, max_epochs=15, batch_size=1, batch_size_test=64, freq_of_test=-1,
114 | learning_rate=1e-3, lr_scheduler_step_size=9, lr_decay=0.9,
115 | per_series_lr_multip=1.0, gradient_eps=1e-8, gradient_clipping_threshold=20,
116 | rnn_weight_decay=0, noise_std=0.001,
117 | level_variability_penalty=80,
118 | testing_percentile=50, training_percentile=50, ensemble=False,
119 | cell_type='LSTM',
120 | state_hsize=40, dilations=[[1, 2], [4, 8]],
121 | add_nl_layer=False, seasonality=[4], input_size=4, output_size=8,
122 | frequency=None, max_periods=20, random_seed=1,
123 | device='cpu', root_dir='./'):
124 | super(ESRNN, self).__init__()
125 | self.mc = ModelConfig(max_epochs=max_epochs, batch_size=batch_size, batch_size_test=batch_size_test,
126 | freq_of_test=freq_of_test, learning_rate=learning_rate,
127 | lr_scheduler_step_size=lr_scheduler_step_size, lr_decay=lr_decay,
128 | per_series_lr_multip=per_series_lr_multip,
129 | gradient_eps=gradient_eps, gradient_clipping_threshold=gradient_clipping_threshold,
130 | rnn_weight_decay=rnn_weight_decay, noise_std=noise_std,
131 | level_variability_penalty=level_variability_penalty,
132 | testing_percentile=testing_percentile, training_percentile=training_percentile,
133 | ensemble=ensemble,
134 | cell_type=cell_type,
135 | state_hsize=state_hsize, dilations=dilations, add_nl_layer=add_nl_layer,
136 | seasonality=seasonality, input_size=input_size, output_size=output_size,
137 | frequency=frequency, max_periods=max_periods, random_seed=random_seed,
138 | device=device, root_dir=root_dir)
139 | self._fitted = False
140 |
141 | def train(self, dataloader, max_epochs,
142 | warm_start=False, shuffle=True, verbose=True):
143 | """
144 | Auxiliary function, pytorch train procedure for the ESRNN model
145 |
146 | Parameters:
147 | -------
148 | dataloader: pytorch dataloader
149 |
max_epochs: int 150 | warm_start: bool 151 | shuffle: bool 152 | verbose: bool 153 | 154 | Returns 155 | ------- 156 | self : returns an instance of self. 157 | """ 158 | 159 | if self.mc.ensemble: 160 | self.esrnn_ensemble = [deepcopy(self.esrnn).to(self.mc.device)] * 5 161 | 162 | if verbose: print(15*'='+' Training ESRNN ' + 15*'=' + '\n') 163 | 164 | # Optimizers 165 | if not warm_start: 166 | self.es_optimizer = optim.Adam(params=self.esrnn.es.parameters(), 167 | lr=self.mc.learning_rate*self.mc.per_series_lr_multip, 168 | betas=(0.9, 0.999), eps=self.mc.gradient_eps) 169 | 170 | self.es_scheduler = StepLR(optimizer=self.es_optimizer, 171 | step_size=self.mc.lr_scheduler_step_size, 172 | gamma=0.9) 173 | 174 | self.rnn_optimizer = optim.Adam(params=self.esrnn.rnn.parameters(), 175 | lr=self.mc.learning_rate, 176 | betas=(0.9, 0.999), eps=self.mc.gradient_eps, 177 | weight_decay=self.mc.rnn_weight_decay) 178 | 179 | self.rnn_scheduler = StepLR(optimizer=self.rnn_optimizer, 180 | step_size=self.mc.lr_scheduler_step_size, 181 | gamma=self.mc.lr_decay) 182 | 183 | # Loss Functions 184 | train_tau = self.mc.training_percentile / 100 185 | train_loss = SmylLoss(tau=train_tau, 186 | level_variability_penalty=self.mc.level_variability_penalty) 187 | 188 | eval_tau = self.mc.testing_percentile / 100 189 | eval_loss = PinballLoss(tau=eval_tau) 190 | 191 | for epoch in range(max_epochs): 192 | self.esrnn.train() 193 | start = time.time() 194 | if shuffle: 195 | dataloader.shuffle_dataset(random_seed=epoch) 196 | losses = [] 197 | for j in range(dataloader.n_batches): 198 | self.es_optimizer.zero_grad() 199 | self.rnn_optimizer.zero_grad() 200 | 201 | batch = dataloader.get_batch() 202 | windows_y, windows_y_hat, levels = self.esrnn(batch) 203 | 204 | # Pinball loss on normalized values 205 | loss = train_loss(windows_y, windows_y_hat, levels) 206 | losses.append(loss.data.cpu().numpy()) 207 | #print("loss", loss) 208 | 209 | loss.backward() 210 | 211 | torch.nn.utils.clip_grad_norm_(self.esrnn.rnn.parameters(), 212 | self.mc.gradient_clipping_threshold) 213 | torch.nn.utils.clip_grad_norm_(self.esrnn.es.parameters(), 214 | self.mc.gradient_clipping_threshold) 215 | self.rnn_optimizer.step() 216 | self.es_optimizer.step() 217 | 218 | # Decay learning rate 219 | self.es_scheduler.step() 220 | self.rnn_scheduler.step() 221 | 222 | if self.mc.ensemble: 223 | copy_esrnn = deepcopy(self.esrnn) 224 | copy_esrnn.eval() 225 | self.esrnn_ensemble.pop(0) 226 | self.esrnn_ensemble.append(copy_esrnn) 227 | 228 | 229 | # Evaluation 230 | self.train_loss = np.mean(losses) 231 | if verbose: 232 | print("========= Epoch {} finished =========".format(epoch)) 233 | print("Training time: {}".format(round(time.time()-start, 5))) 234 | print("Training loss ({} prc): {:.5f}".format(self.mc.training_percentile, 235 | self.train_loss)) 236 | 237 | if (epoch % self.mc.freq_of_test == 0) and (self.mc.freq_of_test > 0): 238 | if self.y_test_df is not None: 239 | self.test_loss = self.model_evaluation(dataloader, eval_loss) 240 | print("Testing loss ({} prc): {:.5f}".format(self.mc.testing_percentile, 241 | self.test_loss)) 242 | self.evaluate_model_prediction(self.y_train_df, self.X_test_df, 243 | self.y_test_df, self.y_hat_benchmark, epoch=epoch) 244 | self.esrnn.train() 245 | 246 | if verbose: print('Train finished! \n') 247 | 248 | def per_series_evaluation(self, dataloader, criterion): 249 | """ 250 | Auxiliary function, evaluate ESRNN model for training 251 | procedure supervision. 
    def per_series_evaluation(self, dataloader, criterion):
        """
        Auxiliary function, evaluates the ESRNN model on each series of
        the panel for training procedure supervision.

        Parameters
        ----------
        dataloader: pytorch dataloader
        criterion: pytorch test criterion

        Returns
        -------
        per_series_losses: list
            loss of each series, for train supervision purposes.
        """

        with torch.no_grad():
            # Create fast dataloader
            new_batch_size = min(self.mc.n_series, self.mc.batch_size_test)
            dataloader.update_batch_size(new_batch_size)

            per_series_losses = []
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.esrnn(batch)
                loss = criterion(windows_y, windows_y_hat)
                per_series_losses += loss.data.cpu().numpy().tolist()

            dataloader.update_batch_size(self.mc.batch_size)
        return per_series_losses

    def model_evaluation(self, dataloader, criterion):
        """
        Auxiliary function, evaluates the ESRNN model for training
        procedure supervision.

        Parameters
        ----------
        dataloader: pytorch dataloader
        criterion: pytorch test criterion

        Returns
        -------
        model_loss: float
            loss for train supervision purposes.
        """

        with torch.no_grad():
            # Create fast dataloader
            new_batch_size = min(self.mc.n_series, self.mc.batch_size_test)
            dataloader.update_batch_size(new_batch_size)

            model_loss = 0.0
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.esrnn(batch)
                loss = criterion(windows_y, windows_y_hat)
                model_loss += loss.data.cpu().numpy()

            model_loss /= dataloader.n_batches
            dataloader.update_batch_size(self.mc.batch_size)
        return model_loss

    def evaluate_model_prediction(self, y_train_df, X_test_df, y_test_df,
                                  y_hat_benchmark='y_hat_naive2', epoch=None):
        """
        Evaluate ESRNN model against benchmark in y_test_df

        Parameters
        ----------
        y_train_df: pandas dataframe
            panel with columns 'unique_id', 'ds', 'y'
        X_test_df: pandas dataframe
            panel with columns 'unique_id', 'ds', 'x'
        y_test_df: pandas dataframe
            panel with columns 'unique_id', 'ds', 'y' and a column
            y_hat_benchmark identifying benchmark predictions
        y_hat_benchmark: str
            column name of the benchmark predictions, default y_hat_naive2
        epoch: int, optional
            current training epoch, used to keep track of the epoch with
            the minimum OWA

        Returns
        -------
        model_owa : float
            relative improvement of the model with respect to the benchmark,
            measured with the M4 overall weighted average.
        mase: float
            relative improvement of the model with respect to the benchmark,
            measured with the M4 mean absolute scaled error.
        smape: float
            relative improvement of the model with respect to the benchmark,
            measured with the symmetric mean absolute percentage error.
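
        Notes
        -----
        OWA follows the M4 competition definition: the average of the
        model's sMAPE and MASE, each divided by the corresponding metric
        of the Naive2 benchmark. A minimal, illustrative call (dataframe
        names are placeholders):

        >>> owa, mase, smape = model.evaluate_model_prediction(
        ...     y_train_df, X_test_df, y_test_df,
        ...     y_hat_benchmark='y_hat_naive2')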
        """

        assert self._fitted, "Model not fitted yet"

        y_panel = y_test_df.filter(['unique_id', 'ds', 'y'])
        y_benchmark_panel = y_test_df.filter(['unique_id', 'ds', y_hat_benchmark])
        y_benchmark_panel = y_benchmark_panel.rename(columns={y_hat_benchmark: 'y_hat'})
        y_hat_panel = self.predict(X_test_df)
        y_insample = y_train_df.filter(['unique_id', 'ds', 'y'])

        model_owa, model_mase, model_smape = owa(y_panel, y_hat_panel,
                                                 y_benchmark_panel, y_insample,
                                                 seasonality=self.mc.naive_seasonality)

        # Keep track of the best OWA seen during training
        if self.min_owa > model_owa:
            self.min_owa = model_owa
            if epoch is not None:
                self.min_epoch = epoch

        print('OWA: {} '.format(np.round(model_owa, 3)))
        print('SMAPE: {} '.format(np.round(model_smape, 3)))
        print('MASE: {} '.format(np.round(model_mase, 3)))

        return model_owa, model_mase, model_smape

    def fit(self, X_df, y_df, X_test_df=None, y_test_df=None, y_hat_benchmark='y_hat_naive2',
            warm_start=False, shuffle=True, verbose=True):
        """
        Fit ESRNN model.

        Parameters
        ----------
        X_df : pandas dataframe
            Train dataframe in long format with columns 'unique_id', 'ds'
            and 'x'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' is a datetime column
            - 'x' is a single exogenous variable
        y_df : pandas dataframe
            Train dataframe in long format with columns 'unique_id', 'ds' and 'y'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' is a datetime column
            - 'y' is the column with the target values
        X_test_df: pandas dataframe
            Optional test dataframe with columns 'unique_id', 'ds' and 'x'.
            If provided, the fit procedure will evaluate the intermediate
            performance within the training epochs.
        y_test_df: pandas dataframe
            Optional test dataframe with columns 'unique_id', 'ds', 'y' and
            the y_hat_benchmark column.
            If provided, the fit procedure will evaluate the intermediate
            performance within the training epochs.
        y_hat_benchmark: str
            Name of the benchmark model for the comparison of the relative
            improvement of the model.

        Returns
        -------
        self : returns an instance of self.
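
        Examples
        --------
        A minimal sketch, assuming long-format dataframes with the columns
        documented above (names are placeholders); y_test_df carries the
        benchmark forecasts in its 'y_hat_naive2' column:

        >>> model.fit(X_train_df, y_train_df,
        ...           X_test_df=X_test_df, y_test_df=y_test_df,
        ...           y_hat_benchmark='y_hat_naive2')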
        """

        # Transform long dfs to wide numpy
        assert isinstance(X_df, pd.DataFrame)
        assert isinstance(y_df, pd.DataFrame)
        assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
        assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
        if y_test_df is not None:
            assert y_hat_benchmark in y_test_df.columns, \
                'benchmark is not present in y_test_df, use y_hat_benchmark to define it'

        # Storing dfs for OWA evaluation, initializing min_owa
        self.y_train_df = y_df
        self.X_test_df = X_test_df
        self.y_test_df = y_test_df
        self.min_owa = 4.0
        self.min_epoch = 0

        # np.int is deprecated; np.integer covers all numpy integer dtypes
        self.int_ds = isinstance(self.y_train_df['ds'][0], (int, np.integer))

        self.y_hat_benchmark = y_hat_benchmark

        X, y = self.long_to_wide(X_df, y_df)
        assert len(X) == len(y)
        assert X.shape[1] >= 3

        # Exogenous variables
        unique_categories = np.unique(X[:, 1])
        self.mc.category_to_idx = dict((word, index) for index, word in enumerate(unique_categories))
        exogenous_size = len(unique_categories)

        # Create batches (device in mc)
        self.train_dataloader = Iterator(mc=self.mc, X=X, y=y)

        # Random Seeds (model initialization)
        torch.manual_seed(self.mc.random_seed)
        np.random.seed(self.mc.random_seed)

        # Initialize model
        n_series = self.train_dataloader.n_series
        self.instantiate_esrnn(exogenous_size, n_series)

        # Validating frequencies
        X_train_frequency = pd.infer_freq(X_df.head()['ds'])
        y_train_frequency = pd.infer_freq(y_df.head()['ds'])
        self.frequencies = [X_train_frequency, y_train_frequency]

        if (X_test_df is not None) and (y_test_df is not None):
            X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
            y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
            self.frequencies += [X_test_frequency, y_test_frequency]

        assert len(set(self.frequencies)) <= 1, \
            "Match the frequencies of the dataframes {}".format(self.frequencies)

        self.mc.frequency = self.frequencies[0]
        print("Inferred frequency: {}".format(self.mc.frequency))

        # Train model
        self._fitted = True
        self.train(dataloader=self.train_dataloader, max_epochs=self.mc.max_epochs,
                   warm_start=warm_start, shuffle=shuffle, verbose=verbose)

        return self

    def instantiate_esrnn(self, exogenous_size, n_series):
        """Auxiliary function used at the beginning of train to instantiate ESRNN"""

        self.mc.exogenous_size = exogenous_size
        self.mc.n_series = n_series
        self.esrnn = _ESRNN(self.mc).to(self.mc.device)

    def predict(self, X_df, decomposition=False):
        """
        Predict using the ESRNN model.

        Parameters
        ----------
        X_df : pandas dataframe
            Dataframe in LONG format with columns 'unique_id', 'ds'
            and 'x'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' is a datetime column
            - 'x' is a single exogenous variable

        Returns
        -------
        Y_hat_panel : pandas dataframe
            Dataframe in LONG format with columns 'unique_id', 'ds'
            and 'y_hat'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' datetime column that matches the dates in X_df
            - 'y_hat' is the column with the predicted target values
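
        Examples
        --------
        A minimal sketch (the dataframe name is a placeholder; the model
        must be fitted first):

        >>> y_hat_df = model.predict(X_test_df)
        >>> y_hat_df[['unique_id', 'ds', 'y_hat']].head()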
        """

        assert isinstance(X_df, pd.DataFrame)
        assert 'unique_id' in X_df
        assert self._fitted, "Model not fitted yet"

        self.esrnn.eval()

        # Create fast dataloader
        new_batch_size = min(self.mc.n_series, self.mc.batch_size_test)
        self.train_dataloader.update_batch_size(new_batch_size)
        dataloader = self.train_dataloader

        # Create Y_hat_panel placeholders
        output_size = self.mc.output_size
        n_unique_id = len(dataloader.sort_key['unique_id'])
        panel_unique_id = pd.Series(dataloader.sort_key['unique_id']).repeat(output_size)

        # Access the column with the last train date of each series and
        # extend it output_size steps into the future
        panel_last_ds = pd.Series(dataloader.X[:, 2])
        panel_ds = []
        for i in range(len(panel_last_ds)):
            ranges = pd.date_range(start=panel_last_ds[i], periods=output_size + 1, freq=self.mc.frequency)
            panel_ds += list(ranges[1:])

        panel_y_hat = np.zeros((output_size * n_unique_id))

        # Predict
        count = 0
        for j in range(dataloader.n_batches):
            batch = dataloader.get_batch()
            batch_size = batch.y.shape[0]

            if self.mc.ensemble:
                # Average the forecasts of the five ensembled snapshots
                y_hat = torch.zeros((5, batch_size, output_size))
                for i in range(5):
                    y_hat[i, :, :] = self.esrnn_ensemble[i].predict(batch)
                y_hat = torch.mean(y_hat, 0)
            else:
                y_hat = self.esrnn.predict(batch)

            y_hat = y_hat.data.cpu().numpy()

            panel_y_hat[count:count + output_size * batch_size] = y_hat.flatten()
            count += output_size * batch_size

        Y_hat_panel_dict = {'unique_id': panel_unique_id,
                            'ds': panel_ds,
                            'y_hat': panel_y_hat}

        assert len(panel_ds) == len(panel_y_hat) == len(panel_unique_id)

        Y_hat_panel = pd.DataFrame.from_dict(Y_hat_panel_dict)

        if 'ds' in X_df:
            Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id', 'ds'], how='left')
        else:
            Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id'], how='left')

        self.train_dataloader.update_batch_size(self.mc.batch_size)
        return Y_hat_panel

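    # Shape sketch for the long-to-wide transform below (values are
    # illustrative): a long panel with two series observed on three shared
    # dates becomes
    #
    #   X: array of shape (2, 3) with columns [unique_id, x, last_ds]
    #   y: array of shape (2, 3) with one row of target values per series
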
    def long_to_wide(self, X_df, y_df):
        """
        Auxiliary function to wrangle LONG format dataframes
        to a wide format compatible with ESRNN inputs.

        Parameters
        ----------
        X_df : pandas dataframe
            Dataframe in long format with columns 'unique_id', 'ds'
            and 'x'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' is a datetime column
            - 'x' is a single exogenous variable
        y_df : pandas dataframe
            Dataframe in long format with columns 'unique_id', 'ds' and 'y'.
            - 'unique_id' an identifier of each independent time series.
            - 'ds' is a datetime column
            - 'y' is the column with the target values

        Returns
        -------
        X: numpy array, shape (n_unique_ids, 3)
            one row per series with columns [unique_id, x, last_ds]
        y: numpy array, shape (n_unique_ids, n_time)
            one row of target values per series
        """
        data = X_df.copy()
        data['y'] = y_df['y'].copy()
        # Map each distinct timestamp to its position in the sorted date index
        sorted_ds = np.sort(data['ds'].unique())
        ds_map = {t: dmap for dmap, t in enumerate(sorted_ds)}
        data['ds_map'] = data['ds'].map(ds_map)
        data = data.sort_values(by=['ds_map', 'unique_id'])
        df_wide = data.pivot(index='unique_id', columns='ds_map')['y']

        x_unique = data[['unique_id', 'x']].groupby('unique_id').first()
        last_ds = data[['unique_id', 'ds']].groupby('unique_id').last()
        assert len(x_unique) == len(data.unique_id.unique())
        df_wide['x'] = x_unique
        df_wide['last_ds'] = last_ds
        df_wide = df_wide.reset_index().rename_axis(None, axis=1)

        ds_cols = data.ds_map.unique().tolist()
        X = df_wide.filter(items=['unique_id', 'x', 'last_ds']).values
        y = df_wide.filter(items=ds_cols).values

        return X, y

    def get_dir_name(self, root_dir=None):
        """Auxiliary function to build the directory in which to save the ESRNN model"""
        if not root_dir:
            assert self.mc.root_dir
            root_dir = self.mc.root_dir

        data_dir = self.mc.dataset_name
        model_parent_dir = os.path.join(root_dir, data_dir)
        model_path = ['esrnn_{}'.format(str(self.mc.copy))]
        model_dir = os.path.join(model_parent_dir, '_'.join(model_path))
        return model_dir

    def save(self, model_dir=None, copy=None):
        """Auxiliary function to save the ESRNN model"""
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        rnn_filepath = os.path.join(model_dir, "rnn.model")
        es_filepath = os.path.join(model_dir, "es.model")

        print('Saving model to:\n {}'.format(model_dir) + '\n')
        torch.save({'model_state_dict': self.esrnn.es.state_dict()}, es_filepath)
        torch.save({'model_state_dict': self.esrnn.rnn.state_dict()}, rnn_filepath)

    def load(self, model_dir=None, copy=None):
        """Auxiliary function to load the ESRNN model"""
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        rnn_filepath = os.path.join(model_dir, "rnn.model")
        es_filepath = os.path.join(model_dir, "es.model")
        path = Path(es_filepath)

        if path.is_file():
            print('Loading model from:\n {}'.format(model_dir) + '\n')

            checkpoint = torch.load(es_filepath, map_location=self.mc.device)
            self.esrnn.es.load_state_dict(checkpoint['model_state_dict'])
            self.esrnn.es.to(self.mc.device)

            checkpoint = torch.load(rnn_filepath, map_location=self.mc.device)
            self.esrnn.rnn.load_state_dict(checkpoint['model_state_dict'])
            self.esrnn.rnn.to(self.mc.device)
        else:
            print('Model path {} does not exist'.format(path))
--------------------------------------------------------------------------------