├── .gitignore
├── LICENSE
├── README.md
├── deep4cast
│   ├── __init__.py
│   ├── custom_layers.py
│   ├── datasets.py
│   ├── forecasters.py
│   ├── metrics.py
│   ├── models.py
│   └── transforms.py
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── custom_layers.rst
│   ├── datasets.rst
│   ├── examples
│   │   └── m4daily.ipynb
│   ├── forecasters.rst
│   ├── get_started.rst
│   ├── images
│   │   └── thumb.jpg
│   ├── index.rst
│   ├── metrics.rst
│   ├── models.rst
│   ├── requirements.txt
│   └── transforms.rst
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Personal files
2 | *personal*
3 |
4 | # Data files
5 | *.pkl
6 | *.tsv
7 | *.xlsx
8 | *.xls
9 | *.ss
10 | *.rds
11 | *.db
12 | *.log
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 |
41 | # Scrapy stuff:
42 | .scrapy
43 |
44 | # Jupyter Notebook
45 | .ipynb_checkpoints
46 |
47 | # Misc
48 | __*/
49 | _*/
50 | .bin/
51 | .idea/
52 | *.pptx
53 | *.ppt
54 | *.docx
55 | *.doc
56 | .DS_Store
57 | .pytest_cache
58 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014-2018, Kenneth Tran
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
5 | conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 |
11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
12 | from this software without specific prior written permission.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
15 | BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
16 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
17 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
18 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
19 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep4cast: Forecasting for Decision Making under Uncertainty
2 |
3 |
4 |
5 | ***This package is under active development. Things may change :-).***
6 |
7 | ``Deep4Cast`` is a scalable machine learning package implemented in ``Python`` and ``PyTorch``. It has a front-end API similar to ``scikit-learn``. It is designed for medium to large time series data sets and allows for modeling of forecast uncertainties.
8 |
9 | The network architecture is based on ``WaveNet``. Regularization and approximate sampling from posterior predictive distributions of forecasts are achieved via ``Concrete Dropout``.
10 |
11 | Documentation is available at [read the docs](https://deep4cast.readthedocs.io/en/latest/).
12 |
13 | ## Installation
14 |
15 | ### Main Requirements
16 | - [python](http://python.org) - version 3.6
17 | - [pytorch](http://pytorch.org) - version 1.0
18 |
19 | ### Source
20 | Before installing we recommend setting up a clean [virtual environment](https://docs.python.org/3.6/tutorial/venv.html).
21 |
22 | From the package directory, install the requirements and then the package.
23 | ```
24 | $ pip install -r requirements.txt
25 | $ python setup.py install
26 | ```
27 |
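## Quick Start

A minimal usage sketch with illustrative hyperparameters, mirroring the tutorial notebook (here ``series`` stands in for a list of ``numpy`` arrays, each of shape ``(n_variables, n_timesteps)``):

```
import torch
from torch.utils.data import DataLoader

from deep4cast.datasets import TimeSeriesDataset
from deep4cast.forecasters import Forecaster
from deep4cast.models import WaveNet
import deep4cast.transforms as transforms

# Chain example transforms: tensorize, log-transform, remove level, select target
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.LogTransform(targets=[0], offset=1.0),
    transforms.RemoveLast(targets=[0]),
    transforms.Target(targets=[0]),
])

# Windowed training examples and a mini-batch loader
dataset = TimeSeriesDataset(series, lookback=128, horizon=14, step=1, transform=transform)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

# WaveNet encoder, optimizer, and a distributional loss
model = WaveNet(input_channels=1, output_channels=1, horizon=14)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = torch.distributions.StudentT

# Train, then sample from the approximate posterior predictive distribution
forecaster = Forecaster(model, loss, optimizer, n_epochs=5, device='cpu')
forecaster.fit(dataloader)
samples = forecaster.predict(dataloader, n_samples=100)
```
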
28 | ## Examples
29 | - [Tutorial Notebooks](https://github.com/MSRDL/Deep4Cast/blob/master/docs/examples)
30 |
31 | ## Authors:
32 | - [Toby Bischoff](http://github.com/bischtob)
33 | - Austin Gross
34 | - [Kenneth Tran](http://www.kentran.net)
35 |
36 | ## References:
37 | - [Concrete Dropout](https://arxiv.org/pdf/1705.07832.pdf) is used for approximate posterior Bayesian inference.
38 | - [Wavenet](https://arxiv.org/pdf/1609.03499.pdf) is used as encoder network.
39 |
--------------------------------------------------------------------------------
/deep4cast/__init__.py:
--------------------------------------------------------------------------------
1 | from .forecasters import *
2 | from .models import *
3 |
4 |
--------------------------------------------------------------------------------
/deep4cast/custom_layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ConcreteDropout(torch.nn.Module):
6 | """Applies Dropout to the input, even at prediction time and learns dropout probability
7 | from the data.
8 |
9 | In convolutional neural networks, we can use dropout to drop entire channels using
10 | the 'channel_wise' argument.
11 |
12 | Arguments:
13 | * dropout_regularizer (float): Should be set to 2 / N, where N is the number of training examples.
14 | * init_range (tuple): Initial range for dropout probabilities.
15 | * channel_wise (boolean): If True, drop entire convolutional channels; otherwise drop individual input elements.
16 |
17 | """
18 | def __init__(self,
19 | dropout_regularizer=1e-5,
20 | init_range=(0.1, 0.3),
21 | channel_wise=False):
22 | super(ConcreteDropout, self).__init__()
23 | self.dropout_regularizer = dropout_regularizer
24 | self.init_range = init_range
25 | self.channel_wise = channel_wise
26 |
27 | # Initialize dropout probability
28 | init_min = np.log(init_range[0]) - np.log(1. - init_range[0])
29 | init_max = np.log(init_range[1]) - np.log(1. - init_range[1])
30 | self.p_logit = torch.nn.Parameter(
31 | torch.empty(1).uniform_(init_min, init_max))
32 |
33 | def forward(self, x):
34 | """Returns input but with randomly dropped out values."""
35 | # Get the dropout probability
36 | p = torch.sigmoid(self.p_logit)
37 |
38 | # Apply Concrete Dropout to input
39 | out = self._concrete_dropout(x, p)
40 |
41 | # Regularization term for dropout parameters
42 | dropout_regularizer = p * torch.log(p)
43 | dropout_regularizer += (1. - p) * torch.log(1. - p)
44 |
45 | # The size of the dropout regularization depends on the kind of input
46 | if self.channel_wise:
47 | # Dropout only applied to channel dimension
48 | input_dim = x.shape[1]
49 | else:
50 | # Dropout applied to all dimensions
51 | input_dim = np.prod(x.shape[1:])
52 | dropout_regularizer *= self.dropout_regularizer * input_dim
53 |
54 | return out, dropout_regularizer.mean()
55 |
56 | def _concrete_dropout(self, x, p):
57 | # Empirical parameters for the concrete distribution
58 | eps = 1e-7
59 | temp = 0.1
60 |
61 | # Apply Concrete dropout channel wise or across all input
62 | if self.channel_wise:
63 | unif_noise = torch.rand_like(x[:, :, [0]])
64 | else:
65 | unif_noise = torch.rand_like(x)
66 |
67 | drop_prob = (torch.log(p + eps)
68 | - torch.log(1 - p + eps)
69 | + torch.log(unif_noise + eps)
70 | - torch.log(1 - unif_noise + eps))
71 | drop_prob = torch.sigmoid(drop_prob / temp)
72 | random_tensor = 1 - drop_prob
73 |
74 | # Need to make sure we have the right shape for the Dropout mask
75 | if self.channel_wise:
76 | random_tensor = random_tensor.repeat([1, 1, x.shape[2]])
77 |
78 | # Drop weights
79 | retain_prob = 1 - p
80 | x = torch.mul(x, random_tensor)
81 | x /= retain_prob
82 |
83 | return x
84 |
85 |
--------------------------------------------------------------------------------
/deep4cast/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 |
4 | from deep4cast import transforms
5 |
6 |
7 | class TimeSeriesDataset(Dataset):
8 | """Takes a list of time series and provides access to windowed subseries for
9 | training.
10 |
11 | Arguments:
12 | * time_series (list): List of time series ``numpy`` arrays.
13 | * lookback (int): Number of time steps used as input for forecasting.
14 | * horizon (int): Number of time steps to forecast.
15 | * step (int): Time step size between consecutive examples.
16 | * transform (``transforms.Compose``): Specific transformations to apply to time series examples.
17 | * static_covs (list): Static covariates for each item in ``time_series`` list.
18 | * thinning (float): Fraction of examples to include.
19 |
20 | """
21 | def __init__(self,
22 | time_series,
23 | lookback,
24 | horizon,
25 | step,
26 | transform,
27 | static_covs=None,
28 | thinning=1.0):
29 | self.time_series = time_series
30 | self.lookback = lookback
31 | self.horizon = horizon
32 | self.step = step
33 | self.transform = transform
34 | self.static_covs = static_covs
35 |
36 | # Slice each time series into examples, assigning IDs to each
37 | last_id = 0
38 | n_dropped = 0
39 | self.example_ids = {}
40 | for i, ts in enumerate(self.time_series):
41 | num_examples = (ts.shape[-1] - self.lookback - self.horizon + self.step) // self.step
42 | # Time series shorter than the forecast horizon need to be dropped.
43 | if ts.shape[-1] < self.horizon:
44 | n_dropped += 1
45 | continue
46 | # For short time series zero pad the input
47 | if ts.shape[-1] < self.lookback + self.horizon:
48 | num_examples = 1
49 | for j in range(num_examples):
50 | self.example_ids[last_id + j] = (i, j * self.step)
51 | last_id += num_examples
52 |
53 | # Inform user about time series that were too short
54 | if n_dropped > 0:
55 | print("Dropped {}/{} time series due to length.".format(
56 | n_dropped, len(self.time_series)
57 | )
58 | )
59 |
60 | # Store the number of training examples
61 | self._len = int(len(self.example_ids) * thinning)
62 |
63 | def __len__(self):
64 | return self._len
65 |
66 | def __getitem__(self, idx):
67 | # Get time series
68 | ts_id, lookback_id = self.example_ids[idx]
69 | ts = self.time_series[ts_id]
70 |
71 | # Prepare input and target. Zero pad if necessary.
72 | if ts.shape[-1] < self.lookback + self.horizon:
73 | # If the time series is too short, we zero pad
74 | X = ts[:, :-self.horizon]
75 | X = np.pad(
76 | X,
77 | pad_width=((0, 0), (self.lookback - X.shape[-1], 0)),
78 | mode='constant',
79 | constant_values=0
80 | )
81 | y = ts[:, -self.horizon:]
82 | else:
83 | X = ts[:, lookback_id:lookback_id + self.lookback]
84 | y = ts[:, lookback_id + self.lookback:lookback_id + self.lookback + self.horizon]
85 |
86 | # Create the input and output for the sample
87 | sample = {'X': X, 'y': y}
88 | sample = self.transform(sample)
89 |
90 | # Static covariates can be attached
91 | if self.static_covs is not None:
92 | sample['X_stat'] = self.static_covs[ts_id]
93 |
94 | return sample
95 |
--------------------------------------------------------------------------------
/deep4cast/forecasters.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | class Forecaster():
9 | """Handles training of a PyTorch model and can be used to generate samples
10 | from approximate posterior predictive distribution.
11 |
12 | Arguments:
13 | * model (``torch.nn.Module``): Instance of Deep4cast :class:`models`.
14 | * loss (``torch.distributions``): PyTorch `distribution <https://pytorch.org/docs/stable/distributions.html>`_ class (e.g., ``torch.distributions.StudentT``).
15 | * optimizer (``torch.optim``): Instance of a PyTorch `optimizer <https://pytorch.org/docs/stable/optim.html>`_.
16 | * n_epochs (int): Number of training epochs.
17 | * device (str): Device used for training (`cpu` or `cuda`).
18 | * checkpoint_path (str): File system path for writing model checkpoints.
19 | * verbose (bool): Verbosity of forecaster.
20 |
21 | """
22 | def __init__(self,
23 | model,
24 | loss,
25 | optimizer,
26 | n_epochs=1,
27 | device='cpu',
28 | checkpoint_path='./',
29 | verbose=True):
30 | self.device = device if torch.cuda.is_available() and 'cuda' in device else 'cpu'
31 | self.model = model.to(self.device)
32 | self.optimizer = optimizer
33 | self.n_epochs = n_epochs
34 | self.loss = loss
35 | self.history = {'training': [], 'validation': []}
36 | self.checkpoint_path = checkpoint_path
37 | self.verbose = verbose
38 |
39 | def fit(self,
40 | dataloader_train,
41 | dataloader_val=None,
42 | eval_model=False):
43 | """Fits a model to a given a dataset.
44 |
45 | Arguments:
46 | * dataloader_train (``torch.utils.data.DataLoader``): Training data.
47 | * dataloader_val (``torch.utils.data.DataLoader``): Validation data.
48 | * eval_model (bool): Flag to switch on model evaluation after every epoch.
49 |
50 | """
51 | # Iterate over training epochs
52 | start_time = time.time()
53 | for epoch in range(1, self.n_epochs + 1):
54 | self._train(dataloader_train, epoch, start_time)
55 | self._save_checkpoint()
56 | if eval_model:
57 | train_loss = self._evaluate(dataloader_train)
58 | if self.verbose: print('\nTraining error: {:1.2e}.'.format(train_loss))
59 | self.history['training'].append(train_loss)
60 | if dataloader_val:
61 | val_loss = self._evaluate(dataloader_val)
62 | if self.verbose: print('Validation error: {:1.2e}.\n'.format(val_loss))
63 | self.history['validation'].append(val_loss)
64 |
65 | def _train(self, dataloader, epoch, start_time):
66 | """Perform training for one epoch.
67 |
68 | Arguments:
69 | * dataloader (``torch.utils.data.DataLoader``): Training data.
70 | * epoch (int): Current training epoch.
71 | * start_time (``time.time``): Clock time of training start.
72 |
73 | """
74 | n_trained = 0
75 | for idx, batch in enumerate(dataloader):
76 | # Send batch to device
77 | inputs = batch['X'].to(self.device)
78 | targets = batch['y'].to(self.device)
79 |
80 | # Backpropagation
81 | self.optimizer.zero_grad()
82 | outputs = self.model(inputs)
83 | reg = outputs.pop('regularizer')
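# The loss is the negative log-likelihood of the targets under the predicted distribution, plus the Concrete Dropout regularizer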
84 | loss = -self.loss(**outputs).log_prob(targets).mean() + reg
85 | if torch.isnan(loss.mean()):
86 | raise ValueError('NaN in training loss.')
87 | loss.mean().backward()
88 | self.optimizer.step()
89 |
90 | # Status update for the user
91 | if self.verbose:
92 | n_trained += len(inputs)
93 | n_total = len(dataloader.dataset)
94 | percentage = 100.0 * (idx + 1) / len(dataloader)
95 | elapsed = time.time() - start_time
96 | remaining = elapsed*((self.n_epochs*n_total)/((epoch-1)*n_total + n_trained) - 1)
97 | status = '\rEpoch {}/{} [{}/{} ({:.0f}%)]\t' \
98 | + 'Loss: {:.6f}\t' \
99 | + 'Elapsed/Remaining: {:.0f}m{:.0f}s/{:.0f}m{:.0f}s '
100 | print(
101 | status.format(
102 | epoch,
103 | self.n_epochs,
104 | n_trained,
105 | n_total,
106 | percentage,
107 | loss.mean().item(),
108 | elapsed // 60,
109 | elapsed % 60,
110 | remaining // 60,
111 | remaining % 60
112 | ),
113 | end=""
114 | )
115 |
116 | def _evaluate(self, dataloader, n_samples=10):
117 | """Returns the approximate min negative log likelihood of the model
118 | averaged over dataset.
119 |
120 | Arguments:
121 | * dataloader (``torch.utils.data.DataLoader``): Evaluation data.
122 | * n_samples (int): Number of forecast samples.
123 |
124 | """
125 | max_llikelihood = [0]*n_samples
126 | with torch.no_grad():
127 | for batch in dataloader:
128 | inputs = batch['X'].to(self.device)
129 | targets = batch['y'].to(self.device)
130 |
131 | # Calculate the log likelihood for each forecast sample. Dropout
132 | # stays active at evaluation time, so every forward pass through
133 | # the model yields a different sample.
134 | for i in range(n_samples):
135 | outputs = self.model(inputs)
136 | outputs.pop('regularizer')
137 | loss = self.loss(**outputs).log_prob(targets)
138 | max_llikelihood[i] += loss.sum().item()
139 | max_llikelihood = np.max(max_llikelihood)
140 |
141 | return -max_llikelihood / len(dataloader.dataset)
142 |
143 | def predict(self, dataloader, n_samples=100) -> np.array:
144 | """Generates predictions.
145 |
146 | Arguments:
147 | * dataloader (``torch.utils.data.DataLoader``): Data to make forecasts.
148 | * n_samples (int): Number of forecast samples.
149 |
150 | """
151 | with torch.no_grad():
152 | predictions = []
153 | for batch in dataloader:
154 | inputs = batch['X'].to(self.device)
155 | samples = []
156 | for i in range(n_samples):
157 | outputs = self.model(inputs)
158 | outputs.pop('regularizer')
159 | outputs = self.loss(**outputs).sample((1,)).cpu()
160 | batch['y'] = outputs[0]
161 | outputs = copy.deepcopy(batch)
162 | outputs = dataloader.dataset.transform.untransform(outputs)
163 | samples.append(outputs['y'][None, :])
164 | samples = np.concatenate(samples, axis=0)
165 | predictions.append(samples)
166 | predictions = np.concatenate(predictions, axis=1)
167 |
168 | return predictions
169 |
170 | def embed(self, dataloader, n_samples=100) -> np.array:
171 | """Generate embedding vectors.
172 |
173 | Arguments:
174 | * dataloader (``torch.utils.data.DataLoader``): Data to make embedding vectors.
175 | * n_samples (int): Number of forecast samples.
176 |
177 | """
178 | with torch.no_grad():
179 | predictions = []
180 | for batch in dataloader:
181 | inputs = batch['X'].to(self.device)
182 | samples = []
183 | for i in range(n_samples):
184 | outputs, __ = self.model.encode(inputs)
185 | samples.append(outputs.cpu().numpy())
186 | samples = np.array(samples)
187 | predictions.append(samples)
188 | predictions = np.concatenate(predictions, axis=1)
189 |
190 | return predictions
191 |
192 | def _save_checkpoint(self):
193 | """Save a complete PyTorch model checkpoint."""
194 | filename = self.checkpoint_path
195 | filename += 'checkpoint_model.pt'
196 | save_dict = {}
197 | save_dict['model_def'] = self.model
198 | save_dict['optimizer_state_dict'] = self.optimizer.state_dict()
199 | save_dict['loss'] = self.loss
200 | torch.save(save_dict, filename)
201 |
202 |
--------------------------------------------------------------------------------
/deep4cast/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import warnings
3 |
4 |
5 | def mae(data_samples, data_truth, agg=None, **kwargs) -> np.array:
6 | """Computes mean absolute error (MAE)
7 |
8 | Arguments:
9 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
10 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
11 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
12 |
13 | """
14 | if data_samples.shape[1:] != data_truth.shape:
15 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
16 | agg = np.median if not agg else agg
17 |
18 | # Aggregate over samples
19 | data = agg(data_samples, axis=0)
20 |
21 | return np.mean(np.abs(data - data_truth), axis=(1, 2))
22 |
23 |
24 | def mape(data_samples, data_truth, agg=None, **kwargs) -> np.array:
25 | """Computes mean absolute percentage error (MAPE)
26 |
27 | Arguments:
28 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
29 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
30 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
31 |
32 | """
33 | if data_samples.shape[1:] != data_truth.shape:
34 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
35 | agg = np.median if not agg else agg
36 |
37 | # Aggregate over samples
38 | data = agg(data_samples, axis=0)
39 |
40 | norm = np.abs(data_truth)
41 |
42 | return np.mean(np.abs(data - data_truth) / norm, axis=(1, 2)) * 100.0
43 |
44 |
45 | def mase(data_samples,
46 | data_truth,
47 | data_insample,
48 | frequencies,
49 | agg=None,
50 | **kwargs) -> np.array:
51 | """Computes mean absolute scaled error (MASE) as in the `M4 competition
52 | <https://github.com/M4Competition/M4-methods>`_.
53 |
54 | Arguments:
55 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
56 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
57 | * data_insample (``np.array``): In-sample time series data (n_timeseries, n_variables, n_timesteps).
58 | * frequencies (list): Frequencies to be used when calculating the naive forecast.
59 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
60 |
61 | """
62 | if data_samples.shape[1:] != data_truth.shape:
63 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
64 | agg = np.median if not agg else agg
65 |
66 | # Calculate mean absolute error for forecast and naive forecast per time series
67 | errs, naive_errs = [], []
68 | for i in range(data_samples.shape[1]):
69 | ts_sample = data_samples[:, i]
70 | ts_truth = data_truth[i]
71 | ts = data_insample[i]
72 | freq = int(frequencies[i])
73 |
74 | data = agg(ts_sample, axis=0)
75 |
76 | # Build mean absolute error
77 | err = np.mean(np.abs(data - ts_truth))
78 |
79 | # naive forecast is calculated using insample
80 | t_in = ts.shape[-1]
81 | naive_forecast = ts[:, :t_in-freq]
82 | naive_target = ts[:, freq:]
83 | err_naive = np.mean(np.abs(naive_target - naive_forecast))
84 |
85 | errs.append(err)
86 | naive_errs.append(err_naive)
87 |
88 | errs = np.array(errs)
89 | naive_errs = np.array(naive_errs)
90 |
91 | return errs / naive_errs
92 |
93 |
94 | def smape(data_samples, data_truth, agg=None, **kwargs) -> np.array:
95 | """Computes symmetric mean absolute percentage error (SMAPE) on the mean
96 |
97 | Arguments:
98 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
99 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
100 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
101 |
102 | """
103 | if data_samples.shape[1:] != data_truth.shape:
104 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
105 | agg = np.median if not agg else agg
106 |
107 | # Aggregate over samples
108 | data = agg(data_samples, axis=0)
109 |
110 | eps = 1e-16 # Need to make sure that denominator is not zero
111 | norm = 0.5 * (np.abs(data) + np.abs(data_truth)) + eps
112 |
113 | return np.mean(np.abs(data - data_truth) / norm, axis=(1, 2)) * 100
114 |
115 |
116 | def mse(data_samples, data_truth, agg=None, **kwargs) -> np.array:
117 | """Computes mean squared error (MSE)
118 |
119 | Arguments:
120 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
121 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
122 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
123 |
124 | """
125 | if data_samples.shape[1:] != data_truth.shape:
126 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
127 | agg = np.median if not agg else agg
128 |
129 | # Aggregate over samples
130 | data = agg(data_samples, axis=0)
131 |
132 | return np.mean(np.square((data - data_truth)), axis=(1, 2))
133 |
134 |
135 | def rmse(data_samples, data_truth, agg=None, **kwargs) -> np.array:
136 | """Computes mean squared error (RMSE)
137 |
138 | Arguments:
139 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
140 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
141 | * agg: Aggregation function applied to sampled predictions (defaults to ``np.median``).
142 |
143 | """
144 | if data_samples.shape[1:] != data_truth.shape:
145 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
146 | agg = np.median if not agg else agg
147 |
148 | # Aggregate over samples
149 | data = agg(data_samples, axis=0)
150 |
151 | return np.sqrt(np.mean(np.square(data - data_truth), axis=(1, 2)))
152 |
153 |
154 | def coverage(data_samples, data_truth, percentiles=None, **kwargs) -> list:
155 | """Computes coverage rates of the prediction interval.
156 |
157 | Arguments:
158 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
159 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
160 | * percentiles (list): percentiles to calculate coverage for
161 |
162 | """
163 | if data_samples.shape[1:] != data_truth.shape:
164 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
165 | if percentiles is None:
166 | percentiles = [0.5, 2.5, 5, 25, 50, 75, 95, 97.5, 99.5]
167 |
168 | data_perc = np.percentile(data_samples, q=percentiles, axis=0)
169 | coverage_percentages = []
170 | for perc in data_perc:
171 | coverage_percentages.append(
172 | np.round(np.mean(data_truth <= perc) * 100.0, 3)
173 | )
174 |
175 | return coverage_percentages
176 |
177 |
178 | def pinball_loss(data_samples, data_truth, percentiles=None, **kwargs) -> np.array:
179 | """Computes pinball loss.
180 |
181 | Arguments:
182 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
183 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
184 | * percentiles (list): Percentiles used to calculate coverage.
185 |
186 | """
187 | if data_samples.shape[1:] != data_truth.shape:
188 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
189 | if percentiles is None:
190 | percentiles = np.linspace(0, 100, 101)
191 |
192 | num_steps = data_samples.shape[2]
193 |
194 | # Calculate percentiles
195 | data_perc = np.percentile(data_samples, q=percentiles, axis=0)
196 |
197 | # Calculate mean pinball loss
198 | total = 0
199 | for perc, q in zip(data_perc, percentiles):
200 | # Calculate upper and lower branch of pinball loss
201 | upper = data_truth - perc
202 | lower = perc - data_truth
203 | upper = np.sum(q / 100.0 * upper[upper >= 0])
204 | lower = np.sum((1 - q / 100.0) * lower[lower > 0])
205 | total += (upper + lower) / num_steps
206 |
207 | # Add overall mean pinball loss
208 | return np.round(total / len(percentiles), 3)
209 |
210 |
211 | def msis(data_samples,
212 | data_truth,
213 | data_insample,
214 | frequencies,
215 | alpha=0.05, **kwargs) -> np.array:
216 | """Mean Scaled Interval Score (MSIS) as shown in the `M4 competition
217 | <https://github.com/M4Competition/M4-methods>`_.
218 |
219 | Arguments:
220 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
221 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
222 | * data_insample (``np.array``): In-sample time series data (n_timeseries, n_variables, n_timesteps).
223 | * frequencies (list): Frequencies to be used when calculating the naive forecast.
224 | * alpha (float): Significance level.
225 |
226 | """
227 | if data_samples.shape[1:] != data_truth.shape:
228 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
229 | lower = (alpha / 2) * 100
230 | upper = 100 - (alpha / 2) * 100
231 |
232 | # drop individual samples for a given time series where the prediction is
233 | # not finite
234 | penalty_us, penalty_ls, scores, seas_diffs = [], [], [], []
235 | for i in range(data_samples.shape[1]):
236 | # Set up individual time series
237 | ts_sample = data_samples[:, i]
238 | ts_truth = data_truth[i]
239 | ts = data_insample[i]
240 | freq = int(frequencies[i])
241 |
242 | mask = np.where(~np.isfinite(ts_sample))[0]
243 | if mask.shape[0] > 0:
244 | mask = np.unique(mask)
245 | warnings.warn('For time series {}, removing {} of {} total samples.'.format(
246 | i, mask.shape[0], ts_sample.shape[0]))
247 | ts_sample = np.delete(ts_sample, mask, axis=0)
248 |
249 | # Calculate percentiles
250 | data_perc = np.percentile(ts_sample, q=(lower, upper), axis=0)
251 |
252 | # Penalty is (lower - actual) + (actual - upper)
253 | penalty_l = data_perc[0] - ts_truth
254 | penalty_l = np.where(penalty_l > 0, penalty_l, 0)
255 | penalty_u = ts_truth - data_perc[1]
256 | penalty_u = np.where(penalty_u > 0, penalty_u, 0)
257 |
258 | penalty_u = (2 / alpha) * np.mean(penalty_u, axis=1)
259 | penalty_l = (2 / alpha) * np.mean(penalty_l, axis=1)
260 |
261 | # Score is upper - lower
262 | score = np.mean(data_perc[1] - data_perc[0], axis=1)
263 |
264 | # Naive forecast is calculated using insample data
265 | t_in = ts.shape[-1]
266 | ts = ts[-t_in:]
267 | naive_forecast = ts[:, :t_in-freq]
268 | naive_target = ts[:, freq:]
269 | seas_diff = np.mean(np.abs(naive_target - naive_forecast))
270 |
271 | penalty_us.append(penalty_u)
272 | penalty_ls.append(penalty_l)
273 | scores.append(score)
274 | seas_diffs.append(seas_diff)
275 |
276 | penalty_us = np.concatenate(penalty_us)
277 | penalty_ls = np.concatenate(penalty_ls)
278 | scores = np.concatenate(scores)
279 | seas_diffs = np.array(seas_diffs)
280 |
281 | return (scores + penalty_us + penalty_ls) / seas_diffs
282 |
283 |
284 | def acd(data_samples, data_truth, alpha=0.05, **kwargs) -> float:
285 | """The absolute difference between the coverage of the method and the target (0.95).
286 |
287 | Arguments:
288 | * data_samples (``np.array``): Sampled predictions (n_samples, n_timeseries, n_variables, n_timesteps).
289 | * data_truth (``np.array``): Ground truth time series values (n_timeseries, n_variables, n_timesteps).
290 | * alpha (float): percentile to compute coverage difference
291 |
292 | """
293 | if data_samples.shape[1:] != data_truth.shape:
294 | raise ValueError('Last three dimensions of data_samples and data_truth need to be compatible')
295 |
296 | alpha = (1 - alpha) * 100
297 | data_perc = np.percentile(data_samples, q=[alpha], axis=0)
298 | acd = alpha - np.round(np.mean(data_truth <= data_perc[0]) * 100.0, 3)
299 | acd = np.abs(acd) / 100
300 |
301 | return acd
302 |
303 |
--------------------------------------------------------------------------------
/deep4cast/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from deep4cast import custom_layers
5 |
6 |
7 | class WaveNet(torch.nn.Module):
8 | """Implements `WaveNet` architecture for time series forecasting. Inherits
9 | from pytorch `Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_.
10 | Vector forecasts are made via a fully-connected linear layer.
11 |
12 | References:
13 | - `WaveNet: A Generative Model for Raw Audio <https://arxiv.org/pdf/1609.03499.pdf>`_
14 |
15 | Arguments:
16 | * input_channels (int): Number of covariates in input time series.
17 | * output_channels (int): Number of target time series.
18 | * horizon (int): Number of time steps to forecast.
19 | * hidden_channels (int): Number of channels in convolutional hidden layers.
20 | * skip_channels (int): Number of channels in convolutional layers for skip connections.
21 | * n_layers (int): Number of layers per Wavenet block (determines receptive field size).
22 | * n_blocks (int): Number of Wavenet blocks.
23 | * dilation (int): Dilation factor for temporal convolution.
24 |
25 | """
26 | def __init__(self,
27 | input_channels,
28 | output_channels,
29 | horizon,
30 | hidden_channels=64,
31 | skip_channels=64,
32 | n_layers=7,
33 | n_blocks=1,
34 | dilation=2):
35 | """Inititalize variables."""
36 | super(WaveNet, self).__init__()
37 | self.output_channels = output_channels
38 | self.horizon = horizon
39 | self.hidden_channels = hidden_channels
40 | self.skip_channels = skip_channels
41 | self.n_layers = n_layers
42 | self.n_blocks = n_blocks
43 | self.dilation = dilation
44 | self.dilations = [dilation**i for i in range(n_layers)] * n_blocks
45 |
46 | # Set up first layer for input
47 | self.do_conv_input = custom_layers.ConcreteDropout(channel_wise=True)
48 | self.conv_input = torch.nn.Conv1d(
49 | in_channels=input_channels,
50 | out_channels=hidden_channels,
51 | kernel_size=1
52 | )
53 |
54 | # Set up main WaveNet layers
55 | self.do, self.conv, self.skip, self.resi = [], [], [], []
56 | for d in self.dilations:
57 | self.do.append(custom_layers.ConcreteDropout(channel_wise=True))
58 | self.conv.append(torch.nn.Conv1d(in_channels=hidden_channels,
59 | out_channels=hidden_channels,
60 | kernel_size=2,
61 | dilation=d))
62 | self.skip.append(torch.nn.Conv1d(in_channels=hidden_channels,
63 | out_channels=skip_channels,
64 | kernel_size=1))
65 | self.resi.append(torch.nn.Conv1d(in_channels=hidden_channels,
66 | out_channels=hidden_channels,
67 | kernel_size=1))
68 | self.do = torch.nn.ModuleList(self.do)
69 | self.conv = torch.nn.ModuleList(self.conv)
70 | self.skip = torch.nn.ModuleList(self.skip)
71 | self.resi = torch.nn.ModuleList(self.resi)
72 |
73 | # Set up nonlinear output layers
74 | self.do_conv_post = custom_layers.ConcreteDropout(channel_wise=True)
75 | self.conv_post = torch.nn.Conv1d(
76 | in_channels=skip_channels,
77 | out_channels=skip_channels,
78 | kernel_size=1
79 | )
80 | self.do_linear_mean = custom_layers.ConcreteDropout()
81 | self.do_linear_std = custom_layers.ConcreteDropout()
82 | self.do_linear_df = custom_layers.ConcreteDropout()
83 | self.linear_mean = torch.nn.Linear(
84 | skip_channels, horizon*output_channels)
85 | self.linear_std = torch.nn.Linear(
86 | skip_channels, horizon*output_channels)
87 | self.linear_df = torch.nn.Linear(
88 | skip_channels, horizon*output_channels)
89 |
90 | def forward(self, inputs):
91 | """Forward function."""
92 | output, reg_e = self.encode(inputs)
93 | output_mean, output_std, output_df, reg_d = self.decode(output)
94 |
95 | # Regularization
96 | regularizer = reg_e + reg_d
97 |
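# Keys match the parameters of the distributional loss (e.g., torch.distributions.StudentT takes df, loc, and scale)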
98 | return {'df': output_df, 'loc': output_mean, 'scale': output_std, 'regularizer': regularizer}
99 |
100 | def encode(self, inputs: torch.Tensor):
101 | """Returns embedding vectors.
102 |
103 | Arguments:
104 | * inputs: time series input to make forecasts for
105 |
106 | """
107 | # Input layer
108 | output, res_conv_input = self.do_conv_input(inputs)
109 | output = self.conv_input(output)
110 |
111 | # Loop over WaveNet layers and blocks
112 | regs, skip_connections = [], []
113 | for do, conv, skip, resi in zip(self.do, self.conv, self.skip, self.resi):
114 | layer_in = output
115 | output, reg = do(layer_in)
116 | output = conv(output)
117 | output = torch.nn.functional.relu(output)
118 | skip = skip(output)
119 | output = resi(output)
120 | output = output + layer_in[:, :, -output.size(2):]
121 | regs.append(reg)
122 | skip_connections.append(skip)
123 |
124 | # Sum up regularizer terms and skip connections
125 | regs = sum(r for r in regs)
126 | output = sum([s[:, :, -output.size(2):] for s in skip_connections])
127 |
128 | # Nonlinear output layers
129 | output, res_conv_post = self.do_conv_post(output)
130 | output = torch.nn.functional.relu(output)
131 | output = self.conv_post(output)
132 | output = torch.nn.functional.relu(output)
133 | output = output[:, :, [-1]]
134 | output = output.transpose(1, 2)
135 |
136 | # Regularization terms
137 | regularizer = res_conv_input \
138 | + regs \
139 | + res_conv_post
140 |
141 | return output, regularizer
142 |
143 | def decode(self, inputs: torch.Tensor):
144 | """Returns forecasts based on embedding vectors.
145 |
146 | Arguments:
147 | * inputs: embedding vectors to generate forecasts for
148 |
149 | """
150 | # Apply dense layer to match output length
151 | output_mean, res_linear_mean = self.do_linear_mean(inputs)
152 | output_std, res_linear_std = self.do_linear_std(inputs)
153 | output_df, res_linear_df = self.do_linear_df(inputs)
154 | output_mean = self.linear_mean(output_mean)
155 | output_std = self.linear_std(output_std).exp()
156 | output_df = self.linear_df(output_df).exp()
157 |
158 | # Reshape the layer output to match targets
159 | # Shape is (batch_size, output_channels, horizon)
160 | batch_size = inputs.shape[0]
161 | output_mean = output_mean.reshape(
162 | (batch_size, self.output_channels, self.horizon)
163 | )
164 | output_std = output_std.reshape(
165 | (batch_size, self.output_channels, self.horizon)
166 | )
167 | output_df = output_df.reshape(
168 | (batch_size, self.output_channels, self.horizon)
169 | )
170 |
171 | # Regularization terms
172 | regularizer = res_linear_mean + res_linear_std + res_linear_df
173 |
174 | return output_mean, output_std, output_df, regularizer
175 |
176 | @property
177 | def n_parameters(self):
178 | """Returns the number of model parameters."""
179 | par = list(self.parameters())
180 | s = sum([np.prod(list(d.size())) for d in par])
181 | return s
182 |
183 | @property
184 | def receptive_field_size(self):
185 | """Returns the length of the receptive field."""
186 | return self.dilation * max(self.dilations)
187 |
--------------------------------------------------------------------------------
/deep4cast/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Compose(object):
5 | r"""Composes several transforms together.
6 |
7 | List of transforms must currently begin with ``ToTensor`` and end with
8 | ``Target``.
9 |
10 | Args:
11 | * transforms (list of ``Transform`` objects): list of transforms to compose.
12 |
13 | Example:
14 | >>> transforms.Compose([
15 | >>> transforms.ToTensor(),
16 | >>> transforms.LogTransform(targets=[0], offset=1.0),
17 | >>> transforms.Target(targets=[0]),
18 | >>> ])
19 | """
20 |
21 | def __init__(self, transforms):
22 | self.transforms = transforms
23 |
24 | def __call__(self, example):
25 | for t in self.transforms:
26 | example = t(example)
27 | return example
28 |
29 | def untransform(self, example):
30 | for t in self.transforms[::-1]:
31 | example = t.untransform(example)
32 | return example
33 |
34 |
35 | class LogTransform(object):
36 | r"""Natural logarithm of target covariate + `offset`.
37 |
38 | .. math:: y_i = log_e ( x_i + \mbox{offset} )
39 |
40 | Args:
41 | * offset (float): amount to add before taking the natural logarithm
42 | * targets (list): list of indices to transform.
43 |
44 | Example:
45 | >>> transforms.LogTransform(targets=[0], offset=1.0)
46 | """
47 |
48 | def __init__(self, targets=None, offset=0.0):
49 | self.offset = offset
50 | self.targets = targets
51 |
52 | def __call__(self, sample):
53 | X = sample['X']
54 | y = sample['y']
55 |
56 | if self.targets:
57 | X[self.targets, :] = torch.log(self.offset + X[self.targets, :])
58 | y[self.targets, :] = torch.log(self.offset + y[self.targets, :])
59 | else:
60 | X = torch.log(self.offset + X)
61 | y = torch.log(self.offset + y)
62 |
63 | sample['X'] = X
64 | sample['y'] = y
65 |
66 | return sample
67 |
68 | def untransform(self, sample):
69 | X, y = sample['X'], sample['y']
70 |
71 | # Unpack nested list of forecasting targets.
72 | Target_targets = [torch.unique(x).tolist()
73 | for x in sample['Target_targets']]
74 | Target_targets = sum(Target_targets, [])
75 |
76 | # If the transform target and forecast target overlap then find the
77 | # corresponding index in the y array.
78 | intersect = set(self.targets).intersection(Target_targets)
79 | indices_y = [i for i, item in enumerate(
80 | Target_targets) if item in intersect]
81 |
82 | if self.targets:
83 | X[:, self.targets, :] = \
84 | torch.exp(X[:, self.targets, :]) - self.offset
85 | else:
86 | X = torch.exp(X) - self.offset
87 | y = torch.exp(y) - self.offset
88 |
89 | # Exponentiate only those forecasting targets where we took the
90 | # natural log.
91 | if len(intersect) > 0:
92 | y[:, indices_y, :] = torch.exp(y[:, indices_y, :]) - self.offset
93 |
94 | sample['X'] = X
95 | sample['y'] = y
96 |
97 | return sample
98 |
99 |
100 | class RemoveLast(object):
101 | r"""Subtract final point in lookback window from all points in example.
102 |
103 | Args:
104 | * targets (list): list of indices to transform.
105 |
106 | Example:
107 | >>> transforms.RemoveLast(targets=[0])
108 | """
109 |
110 | def __init__(self, targets=None):
111 | self.targets = targets
112 |
113 | def __call__(self, sample):
114 | X, y = sample['X'], sample['y']
115 |
116 | if self.targets:
117 | offset = X[self.targets, -1]
118 | X[self.targets, :] = X[self.targets, :] - offset[:, None]
119 | y[self.targets, :] = y[self.targets, :] - offset[:, None]
120 | else:
121 | offset = X[:, -1]
122 | X = X - offset[:, None]
123 | y = y - offset[:, None]
124 |
125 | sample['RemoveLast_offset'] = offset
126 |
127 | return sample
128 |
129 | def untransform(self, sample):
130 | X, y = sample['X'], sample['y']
131 | offset = sample['RemoveLast_offset']
132 |
133 | # Unpack nested list of forecasting targets.
134 | Target_targets = \
135 | [torch.unique(x).tolist() for x in sample['Target_targets']]
136 | Target_targets = sum(Target_targets, [])
137 |
138 | # If the transform target and forecast target overlap then find the
139 | # corresponding index in the y array.
140 | intersect = set(self.targets).intersection(Target_targets)
141 |
142 | if self.targets:
143 | X[:, self.targets, :] = \
144 | X[:, self.targets, :] + offset[:, :, None].float()
145 | else:
146 | X += offset[:, :, None].float()
147 | y += offset[:, Target_targets, None].float()
148 |
149 | # Add back to the correct forecasted index the quantity removed
150 | if len(intersect) > 0:
151 | indices_o = \
152 | [i for i, item in enumerate(self.targets) if item in intersect]
153 | indices_y = \
154 | [i for i, item in enumerate(
155 | Target_targets) if item in intersect]
156 | y[:, indices_y, :] = \
157 | y[:, indices_y, :] + offset[:, indices_o, None].float()
158 |
159 | sample['X'] = X
160 | sample['y'] = y
161 |
162 | return sample
163 |
164 |
165 | class ToTensor(object):
166 | r"""Convert ``numpy.ndarrays`` to tensor.
167 |
168 | Args:
169 | * device (str): device on which to load the tensor.
170 |
171 | Example:
172 | >>> transforms.ToTensor(device='cpu')
173 | """
174 |
175 | def __init__(self, device='cpu'):
176 | self.device = torch.device(device)
177 |
178 | def __call__(self, sample):
179 | sample['X'] = torch.tensor(sample['X'], device=self.device).float()
180 | sample['y'] = torch.tensor(sample['y'], device=self.device).float()
181 |
182 | return sample
183 |
184 | def untransform(self, sample):
185 | return sample
186 |
187 |
188 | class Target(object):
189 | r"""Retain only target indices for output.
190 |
191 | Args:
192 | * targets (list): list of indices to retain.
193 |
194 | Example:
195 | >>> transforms.Target(targets=[0])
196 | """
197 |
198 | def __init__(self, targets):
199 | self.targets = targets
200 |
201 | def __call__(self, sample):
202 | sample['y'] = sample['y'][self.targets, :]
203 | sample['Target_targets'] = self.targets
204 |
205 | return sample
206 |
207 | def untransform(self, sample):
208 | return sample
209 |
210 |
211 | class Standardize(object):
212 | """Subtract the mean and divide by the standard deviation from the lookback.
213 |
214 | Args:
215 | * targets (list): list of indices to transform.
216 |
217 | Example:
218 | >>> transforms.Standardize(targets=[0])
219 | """
220 |
221 | def __init__(self, targets=None):
222 | self.targets = targets
223 |
224 | def __call__(self, sample):
225 | X, y = sample['X'], sample['y']
226 |
227 | # Remove mean from X and y and rescale by standard deviation
228 | if self.targets:
229 | mean = X[self.targets, :].mean(dim=1)
230 | std = X[self.targets, :].std(dim=1)
231 | X[self.targets, :] -= mean[:, None]
232 | X[self.targets, :] /= std[:, None]
233 | y[self.targets, :] -= mean[:, None]
234 | y[self.targets, :] /= std[:, None]
235 | else:
236 | mean = X.mean(dim=1)
237 | std = X.std(dim=1)
238 | X -= mean[:, None]
239 | X /= std[:, None]
240 | y -= mean[:, None]
241 | y /= std[:, None]
242 |
243 | sample['X'] = X
244 | sample['y'] = y
245 | sample['Standardize_mean'] = mean
246 | sample['Standardize_std'] = std
247 |
248 | return sample
249 |
250 | def untransform(self, sample):
251 | X, y = sample['X'], sample['y']
252 |
253 | # Unpack nested list of forecasting targets.
254 | Target_targets = \
255 | [torch.unique(x).tolist() for x in sample['Target_targets']]
256 | Target_targets = sum(Target_targets, [])
257 |
258 | # If the transform target and forecast target overlap then find the
259 | # corresponding index in the y array.
260 | intersect = set(self.targets).intersection(Target_targets)
261 |
262 | if self.targets:
263 | X[:, self.targets, :] = \
264 | X[:, self.targets, :] * sample['Standardize_std'][:, :, None]
265 | X[:, self.targets, :] = \
266 | X[:, self.targets, :] + sample['Standardize_mean'][:, :, None]
267 | else:
268 | X = X * sample['Standardize_std']
269 | X = X + sample['Standardize_mean']
270 | y = y * sample['Standardize_std'][:, Target_targets, None]
271 | y = y + sample['Standardize_mean'][:, Target_targets, None]
272 |
273 | # Add back to the correct index the quantity removed
274 | if len(intersect) > 0:
275 | # indices for the offset
276 | indices_o = \
277 | [i for i, item in enumerate(self.targets) if item in intersect]
278 | # indices for the target
279 | indices_y = \
280 | [i for i, item in enumerate(
281 | Target_targets) if item in intersect]
282 | y[:, indices_y, :] = \
283 | y[:, indices_y, :] * sample['Standardize_std'][:, indices_o, None]
284 | y[:, indices_y, :] = \
285 | y[:, indices_y, :] + sample['Standardize_mean'][:, indices_o, None]
286 |
287 | return sample
288 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 |
5 | sys.path.insert(0, os.path.abspath('../'))
6 | sys.path.insert(1, os.path.abspath('../deep4cast/'))
7 |
8 | extensions = ['sphinx.ext.autodoc',
9 | 'sphinx.ext.mathjax',
10 | 'nbsphinx']
11 | source_suffix = '.rst'
12 | master_doc = 'index'
13 | project = u'Deep4Cast'
14 | copyright = u''
15 | exclude_patterns = ['_build', '**.ipynb_checkpoints']
16 | pygments_style = 'sphinx'
17 | html_theme = "sphinx_rtd_theme"
18 | html_logo = "images/thumb.jpg"
19 | html_theme_options = {
20 | 'logo_only': True,
21 | "style_nav_header_background" : "#3176BB"
22 | }
23 | autoclass_content = "both"
24 | use_system_site_packages = True
25 | autodoc_mock_imports = ["numpy", "torch"]
--------------------------------------------------------------------------------
/docs/custom_layers.rst:
--------------------------------------------------------------------------------
1 | Custom Layers
2 | =============
3 |
4 | Custom layers that can be used to build extended PyTorch models for forecasting.
5 |
6 | References:
7 | - `Concrete Dropout <https://arxiv.org/pdf/1705.07832.pdf>`_ is used for approximate posterior Bayesian inference.
8 |
9 | .. automodule:: custom_layers
10 | :members:
11 |
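A minimal sketch of using ``ConcreteDropout`` inside a model (illustrative shapes and a hypothetical training-set size of 50,000 examples; the layer returns the dropped-out tensor together with a regularization term that should be added to the training loss)::

    import torch
    from deep4cast.custom_layers import ConcreteDropout

    dropout = ConcreteDropout(dropout_regularizer=2.0 / 50000, channel_wise=True)
    conv = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=2)

    x = torch.randn(32, 4, 64)         # (batch, channels, time)
    out, regularizer = dropout(x)      # the dropout mask is sampled even at prediction time
    out = conv(out)                    # add `regularizer` to the loss during training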
--------------------------------------------------------------------------------
/docs/datasets.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ========
3 |
4 | Inherits from `pytorch datasets <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`_
5 | to allow use with `pytorch dataloader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
6 |
7 | .. automodule:: datasets
8 | :members:
9 |
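A small sketch of constructing windowed training examples (``series`` is a placeholder for a list of ``numpy`` arrays of shape ``(n_variables, n_timesteps)``)::

    from torch.utils.data import DataLoader

    from deep4cast.datasets import TimeSeriesDataset
    import deep4cast.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RemoveLast(targets=[0]),
        transforms.Target(targets=[0]),
    ])

    # Each example is a dict with a lookback window 'X' and a forecast window 'y'
    dataset = TimeSeriesDataset(series, lookback=128, horizon=14, step=1, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)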
--------------------------------------------------------------------------------
/docs/examples/m4daily.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Tutorial: M4 Daily\n",
8 | "\n",
9 | "This notebook is designed to give a simple introduction to forecasting using the Deep4Cast package. The time series data is taken from the [M4 dataset](https://github.com/M4Competition/M4-methods/tree/master/Dataset), specifically, the ``Daily`` subset of the data. "
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {
16 | "ExecuteTime": {
17 | "end_time": "2019-06-28T17:15:02.007580Z",
18 | "start_time": "2019-06-28T17:15:01.380345Z"
19 | },
20 | "scrolled": true
21 | },
22 | "outputs": [],
23 | "source": [
24 | "import numpy as np\n",
25 | "import os\n",
26 | "import pandas as pd\n",
27 | "import datetime as dt\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "\n",
30 | "import torch\n",
31 | "from torch.utils.data import DataLoader\n",
32 | "\n",
33 | "from deep4cast.forecasters import Forecaster\n",
34 | "from deep4cast.models import WaveNet\n",
35 | "from deep4cast.datasets import TimeSeriesDataset\n",
36 | "import deep4cast.transforms as transforms\n",
37 | "import deep4cast.metrics as metrics\n",
38 | "\n",
39 | "# Make RNG predictable\n",
40 | "np.random.seed(0)\n",
41 | "torch.manual_seed(0)\n",
42 | "# Use a gpu if available, otherwise use cpu\n",
43 | "device = ('cuda' if torch.cuda.is_available() else 'cpu')\n",
44 | "\n",
45 | "%matplotlib inline"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "## Dataset\n",
53 | "In this section we inspect the dataset, split it into a training and a test set, and prepare it for easy consuption with PyTorch-based data loaders. Model construction and training will be done in the next section."
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 2,
59 | "metadata": {
60 | "ExecuteTime": {
61 | "end_time": "2019-06-28T17:15:02.017357Z",
62 | "start_time": "2019-06-28T17:15:02.011736Z"
63 | }
64 | },
65 | "outputs": [],
66 | "source": [
67 | "if not os.path.exists('data/Daily-train.csv'):\n",
68 | " !wget https://raw.githubusercontent.com/M4Competition/M4-methods/master/Dataset/Train/Daily-train.csv -P data/\n",
69 | "if not os.path.exists('data/Daily-test.csv'):\n",
70 | " !wget https://raw.githubusercontent.com/M4Competition/M4-methods/master/Dataset/Test/Daily-test.csv -P data/"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 3,
76 | "metadata": {
77 | "ExecuteTime": {
78 | "end_time": "2019-06-28T17:15:18.767394Z",
79 | "start_time": "2019-06-28T17:15:02.019564Z"
80 | }
81 | },
82 | "outputs": [],
83 | "source": [
84 | "data_arr = pd.read_csv('data/Daily-train.csv')\n",
85 | "data_arr = data_arr.iloc[:, 1:].values\n",
86 | "data_arr = list(data_arr)\n",
87 | "for i, ts in enumerate(data_arr):\n",
88 | " data_arr[i] = ts[~np.isnan(ts)][None, :]"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "### Divide into train and test\n",
96 | "We use the DataLoader object from PyTorch to build batches from the test data set.\n",
97 | "\n",
98 | "However, we first need to specify how much history to use in creating a forecast of a given length:\n",
99 | "- horizon = time steps to forecast\n",
100 | "- lookback = time steps leading up to the period to be forecast"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 4,
106 | "metadata": {
107 | "ExecuteTime": {
108 | "end_time": "2019-06-28T17:15:18.771334Z",
109 | "start_time": "2019-06-28T17:15:18.769032Z"
110 | }
111 | },
112 | "outputs": [],
113 | "source": [
114 | "horizon = 14\n",
115 | "lookback = 128"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "We've also found that it is not necessary to train on the full dataset, so we here select a 10% random sample of time series for training. We will evaluate on the full dataset later."
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 5,
128 | "metadata": {
129 | "ExecuteTime": {
130 | "end_time": "2019-06-28T17:15:18.873938Z",
131 | "start_time": "2019-06-28T17:15:18.772798Z"
132 | }
133 | },
134 | "outputs": [],
135 | "source": [
136 | "import random\n",
137 | "\n",
138 | "data_train = []\n",
139 | "for time_series in data_arr:\n",
140 | " data_train.append(time_series[:, :-horizon],)\n",
141 | "data_train = random.sample(data_train, int(len(data_train) * 0.1))"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "We follow [Torchvision](https://pytorch.org/docs/stable/torchvision) in processing examples using [Transforms](https://pytorch.org/docs/stable/torchvision/transforms.html) chained together by [Compose](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Compose).\n",
149 | "\n",
150 | "* `Tensorize` creates a tensor of the example.\n",
151 | "* `LogTransform` natural logarithm of the targets after adding the offset (similar to [torch.log1p](https://pytorch.org/docs/stable/torch.html#torch.log1p)).\n",
152 | "* `RemoveLast` subtracts the final value in the `lookback` from both `lookback` and `horizon`.\n",
153 | "* `Target` specifies which index in the array to forecast.\n",
154 | "\n",
155 | "We need to perform these transformations to have input features that are of the unit scale. If the input features are not of unit scale (i.e., of O(1)) for all features, the optimizer won't be able to find an optimium due to blow-ups in the gradient calculations."
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 6,
161 | "metadata": {
162 | "ExecuteTime": {
163 | "end_time": "2019-06-28T17:15:18.950829Z",
164 | "start_time": "2019-06-28T17:15:18.876296Z"
165 | }
166 | },
167 | "outputs": [],
168 | "source": [
169 | "transform = transforms.Compose([\n",
170 | " transforms.ToTensor(),\n",
171 | " transforms.LogTransform(targets=[0], offset=1.0),\n",
172 | " transforms.RemoveLast(targets=[0]),\n",
173 | " transforms.Target(targets=[0]),\n",
174 | "])"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "`TimeSeriesDataset` inherits from [Torch Datasets](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) for use with [Torch DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). It handles the creation of the examples used to train the network using `lookback` and `horizon` to partition the time series.\n",
182 | "\n",
183 | "The parameter 'step' controls how far apart consective windowed samples from a time series are spaced. For example, for a time series of length 100 and a setup with lookback 24 and horizon 12, we split the original time series into smaller training examples of length 24+12=36. How much these examples are overlapping is controlled by the parameter `step` in `TimeSeriesDataset`."
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 7,
189 | "metadata": {
190 | "ExecuteTime": {
191 | "end_time": "2019-06-28T17:15:19.243876Z",
192 | "start_time": "2019-06-28T17:15:18.954125Z"
193 | }
194 | },
195 | "outputs": [],
196 | "source": [
197 | "data_train = TimeSeriesDataset(\n",
198 | " data_train, \n",
199 | " lookback, \n",
200 | " horizon,\n",
201 | " step=1,\n",
202 | " transform=transform\n",
203 | ")\n",
204 | "\n",
205 | "# Create mini-batch data loader\n",
206 | "dataloader_train = DataLoader(\n",
207 | " data_train, \n",
208 | " batch_size=512, \n",
209 | " shuffle=True, \n",
210 | " pin_memory=True,\n",
211 | " num_workers=1\n",
212 | ")"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "## Modeling and Forecasting"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "metadata": {},
225 | "source": [
226 | "### Temporal Convolutions\n",
227 | "The network architecture used here is based on ideas related to [WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/). We employ the same architecture with a few modifications (e.g., a fully connected output layer for vector forecasts). It turns out that we do not need many layers in this example to achieve state-of-the-art results, most likely because of the simple autoregressive nature of the data.\n",
228 | "\n",
229 | "In many ways, a temporal convoluational architecture is among the simplest possible architecures that we could employ using neural networks. In our approach, every layer has the same number of convolutional filters and uses residual connections.\n",
230 | "\n",
231 | "When it comes to loss functions, we use the log-likelihood of probability distributions from the `torch.distributions` module. This mean that if one supplues a normal distribution the likelihood of the transformed data is modeled as coming from a normal distribution."
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 8,
237 | "metadata": {
238 | "ExecuteTime": {
239 | "end_time": "2019-06-28T17:15:19.261939Z",
240 | "start_time": "2019-06-28T17:15:19.246822Z"
241 | }
242 | },
243 | "outputs": [
244 | {
245 | "name": "stdout",
246 | "output_type": "stream",
247 | "text": [
248 | "Number of model parameters: 341347.\n",
249 | "Receptive field size: 128.\n",
250 | "Using 2 GPUs.\n"
251 | ]
252 | }
253 | ],
254 | "source": [
255 | "# Define the model architecture\n",
256 | "model = WaveNet(input_channels=1,\n",
257 | " output_channels=1,\n",
258 | " horizon=horizon, \n",
259 | " hidden_channels=89,\n",
260 | " skip_channels=199,\n",
261 | " n_layers=7)\n",
262 | "\n",
263 | "print('Number of model parameters: {}.'.format(model.n_parameters))\n",
264 | "print('Receptive field size: {}.'.format(model.receptive_field_size))\n",
265 | "\n",
266 | "# Enable multi-gpu if available\n",
267 | "if torch.cuda.device_count() > 1:\n",
268 | " print('Using {} GPUs.'.format(torch.cuda.device_count()))\n",
269 | " model = torch.nn.DataParallel(model)\n",
270 | "\n",
271 | "# .. and the optimizer\n",
272 | "optim = torch.optim.Adam(model.parameters(), lr=0.0008097436666349985)\n",
273 | "\n",
274 | "# .. and the loss\n",
275 | "loss = torch.distributions.StudentT"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 9,
281 | "metadata": {
282 | "ExecuteTime": {
283 | "end_time": "2019-06-28T17:52:16.907027Z",
284 | "start_time": "2019-06-28T17:15:19.263466Z"
285 | }
286 | },
287 | "outputs": [
288 | {
289 | "name": "stderr",
290 | "output_type": "stream",
291 | "text": [
292 | "/home/austin/miniconda3/envs/d4cGithub/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
293 | " warnings.warn('Was asked to gather along dimension 0, but all '\n"
294 | ]
295 | },
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "Epoch 1/5 [915731/915731 (100%)]\tLoss: -1.863526\tElapsed/Remaining: 3m52s/15m30s \n",
301 | "Training error: -2.67e+01.\n",
302 | "Epoch 2/5 [915731/915731 (100%)]\tLoss: -1.963631\tElapsed/Remaining: 11m21s/17m2s \n",
303 | "Training error: -2.71e+01.\n",
304 | "Epoch 3/5 [915731/915731 (100%)]\tLoss: -1.983338\tElapsed/Remaining: 18m42s/12m28s \n",
305 | "Training error: -2.75e+01.\n",
306 | "Epoch 4/5 [915731/915731 (100%)]\tLoss: -1.974977\tElapsed/Remaining: 26m2s/6m30s \n",
307 | "Training error: -2.78e+01.\n",
308 | "Epoch 5/5 [915731/915731 (100%)]\tLoss: -2.073579\tElapsed/Remaining: 33m20s/0m0s \n",
309 | "Training error: -2.83e+01.\n"
310 | ]
311 | }
312 | ],
313 | "source": [
314 | "# Fit the forecaster\n",
315 | "forecaster = Forecaster(model, loss, optim, n_epochs=5, device=device)\n",
316 | "forecaster.fit(dataloader_train, eval_model=True)"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "## Evaluation\n",
324 | "Before any evaluation score can be calculated, we load the held out test data."
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 10,
330 | "metadata": {
331 | "ExecuteTime": {
332 | "end_time": "2019-06-28T17:52:33.409674Z",
333 | "start_time": "2019-06-28T17:52:16.911086Z"
334 | }
335 | },
336 | "outputs": [],
337 | "source": [
338 | "data_train = pd.read_csv('data/Daily-train.csv')\n",
339 | "data_test = pd.read_csv('data/Daily-test.csv')\n",
340 | "data_train = data_train.iloc[:, 1:].values\n",
341 | "data_test = data_test.iloc[:, 1:].values\n",
342 | "\n",
343 | "data_arr = []\n",
344 | "for ts_train, ts_test in zip(data_train, data_test):\n",
345 | " ts_a = ts_train[~np.isnan(ts_train)]\n",
346 | " ts_b = ts_test\n",
347 | " ts = np.concatenate([ts_a, ts_b])[None, :]\n",
348 | " data_arr.append(ts)"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 11,
354 | "metadata": {
355 | "ExecuteTime": {
356 | "end_time": "2019-06-28T17:52:33.421253Z",
357 | "start_time": "2019-06-28T17:52:33.411359Z"
358 | }
359 | },
360 | "outputs": [],
361 | "source": [
362 | "# Sequentialize the training and testing dataset\n",
363 | "data_test = []\n",
364 | "for time_series in data_arr:\n",
365 | " data_test.append(time_series[:, -horizon-lookback:])\n",
366 | "\n",
367 | "data_test = TimeSeriesDataset(\n",
368 | " data_test, \n",
369 | " lookback, \n",
370 | " horizon, \n",
371 | " step=1,\n",
372 | " transform=transform\n",
373 | ")\n",
374 | "dataloader_test = DataLoader(\n",
375 | " data_test, \n",
376 | " batch_size=1024, \n",
377 | " shuffle=False,\n",
378 | " num_workers=2\n",
379 | ")"
380 | ]
381 | },
382 | {
383 | "cell_type": "markdown",
384 | "metadata": {},
385 | "source": [
386 | "We need to transform the output forecasts. The output from the foracaster is of the form (n_samples, n_time_series, n_variables, n_timesteps).\n",
387 | "This means, that a point forcast needs to be calculated from the samples, for example, by taking the mean or the median."
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 12,
393 | "metadata": {
394 | "ExecuteTime": {
395 | "end_time": "2019-06-28T17:52:55.851568Z",
396 | "start_time": "2019-06-28T17:52:33.422806Z"
397 | }
398 | },
399 | "outputs": [],
400 | "source": [
401 | "# Get time series of actuals for the testing period\n",
402 | "y_test = []\n",
403 | "for example in dataloader_test:\n",
404 | " example = dataloader_test.dataset.transform.untransform(example)\n",
405 | " y_test.append(example['y'])\n",
406 | "y_test = np.concatenate(y_test)\n",
407 | "\n",
408 | "# Get corresponding predictions\n",
409 | "y_samples = forecaster.predict(dataloader_test, n_samples=100)"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {},
415 | "source": [
416 | "We calculate the [symmetric MAPE](https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error)."
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 13,
422 | "metadata": {
423 | "ExecuteTime": {
424 | "end_time": "2019-06-28T17:52:55.953031Z",
425 | "start_time": "2019-06-28T17:52:55.853679Z"
426 | }
427 | },
428 | "outputs": [
429 | {
430 | "name": "stdout",
431 | "output_type": "stream",
432 | "text": [
433 | "SMAPE: 3.1666347980499268%\n"
434 | ]
435 | }
436 | ],
437 | "source": [
438 | "# Evaluate forecasts\n",
439 | "test_smape = metrics.smape(y_samples, y_test)\n",
440 | "\n",
441 | "print('SMAPE: {}%'.format(test_smape.mean()))"
442 | ]
443 | }
444 | ],
445 | "metadata": {
446 | "kernelspec": {
447 | "display_name": "d4cGithub",
448 | "language": "python",
449 | "name": "d4cgithub"
450 | },
451 | "language_info": {
452 | "codemirror_mode": {
453 | "name": "ipython",
454 | "version": 3
455 | },
456 | "file_extension": ".py",
457 | "mimetype": "text/x-python",
458 | "name": "python",
459 | "nbconvert_exporter": "python",
460 | "pygments_lexer": "ipython3",
461 | "version": "3.6.7"
462 | },
463 | "toc": {
464 | "base_numbering": 1,
465 | "nav_menu": {},
466 | "number_sections": true,
467 | "sideBar": true,
468 | "skip_h1_title": false,
469 | "title_cell": "Table of Contents",
470 | "title_sidebar": "Contents",
471 | "toc_cell": true,
472 | "toc_position": {},
473 | "toc_section_display": true,
474 | "toc_window_display": false
475 | }
476 | },
477 | "nbformat": 4,
478 | "nbformat_minor": 1
479 | }
480 |
--------------------------------------------------------------------------------
/docs/forecasters.rst:
--------------------------------------------------------------------------------
1 | Forecasters
2 | ===========
3 |
4 | Module that handles all forecaster objects for training PyTorch models.
5 |
6 | .. automodule:: forecasters
7 | :members:
8 |
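9 | Example
10 | -------
11 | 
12 | A minimal sketch of the training and prediction loop, mirroring the example notebook. It assumes a ``WaveNet`` model and ``DataLoader`` objects (``dataloader_train``, ``dataloader_test``) have already been built, that ``Forecaster`` is importable from ``deep4cast.forecasters``, and that forecast samples come back as a NumPy array:
13 | 
14 | .. code-block:: python
15 | 
16 |     import numpy as np
17 |     import torch
18 | 
19 |     from deep4cast.forecasters import Forecaster
20 | 
21 |     # Optimizer and a distribution class used as the likelihood-based loss
22 |     optim = torch.optim.Adam(model.parameters(), lr=1e-3)
23 |     loss = torch.distributions.StudentT
24 | 
25 |     # device is 'cuda' or 'cpu', depending on what is available
26 |     forecaster = Forecaster(model, loss, optim, n_epochs=5, device='cuda')
27 |     forecaster.fit(dataloader_train, eval_model=True)
28 | 
29 |     # Draw forecast samples and reduce them to point forecasts, e.g. via the median
30 |     y_samples = forecaster.predict(dataloader_test, n_samples=100)
31 |     y_point = np.median(y_samples, axis=0)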
--------------------------------------------------------------------------------
/docs/get_started.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | Getting Started
3 | ===============
4 |
5 | Deep4Cast is a deep learning-based forecasting package built on PyTorch. It can be used to build forecasters from PyTorch models that are trained over large sets of time series.
6 |
7 | Main Requirements
8 | =================
9 |
10 | - `python 3.6 `_
11 | - `pytorch 1.0 `_
12 |
13 | Installation
14 | ============
15 |
16 | Deep4cast can be cloned from `GitHub `_. Before installing, we recommend setting up a clean `virtual environment `_.
17 |
18 | From the package directory, install the requirements and then the package.
19 |
20 | .. code-block::
21 |
22 | $ pip install -r requirements.txt
23 | $ python setup.py install
24 |
25 |
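26 | After installation, the package should import cleanly (a quick sanity check):
27 | 
28 | .. code-block:: python
29 | 
30 |     import deep4cast  # no output is expected if the installation succeeded
31 | 
32 | For an end-to-end walkthrough, see the M4 example notebook in the documentation.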
--------------------------------------------------------------------------------
/docs/images/thumb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MSRDL/Deep4Cast/c9ddf868d203597114e20e075f4f2dcf6411b4df/docs/images/thumb.jpg
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | =======================
2 | Deep4Cast Documentation
3 | =======================
4 |
5 | Forecasting for decision making under uncertainty
6 | =================================================
7 |
8 | **This package is under active development. Things may change :-).**
9 |
10 | ``Deep4Cast`` is a scalable machine learning package implemented in ``Python``
11 | and ``Torch``. It has a front-end API similar to ``scikit-learn``. It is
12 | designed for medium to large time series data sets and allows for modeling of
13 | forecast uncertainties.
14 |
15 | The network architecture is based on ``WaveNet``. Regularization and
16 | approximate sampling from posterior predictive distributions of forecasts are
17 | achieved via ``Concrete Dropout``.
18 |
19 | Examples
20 | --------
21 |
22 | :ref:`/examples/m4daily.ipynb`
23 |
24 |
25 | Authors
26 | -------
27 | - `Toby Bischoff `_
28 | - Austin Gross
29 | - `Kenneth Tran `_
30 |
31 |
32 | References
33 | ----------
34 | - `Concrete Dropout `_ is used for approximate posterior Bayesian inference.
35 | - `Wavenet `_ is used as the encoder network.
36 |
37 |
38 | .. toctree::
39 | :maxdepth: 2
40 | :glob:
41 | :hidden:
42 |
43 | get_started
44 | examples/*
45 | datasets
46 | transforms
47 | models
48 | forecasters
49 | metrics
50 | custom_layers
51 |
--------------------------------------------------------------------------------
/docs/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | =======
3 |
4 | Common evaluation metrics for time series forecasts.
5 |
6 | .. automodule:: metrics
7 | :members:
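8 | 
9 | Example
10 | -------
11 | 
12 | A small sketch of scoring forecasts with the symmetric MAPE, mirroring the example notebook. It assumes ``metrics`` is imported from the ``deep4cast`` package and that ``y_samples`` (forecast samples) and ``y_test`` (actuals) are NumPy arrays of compatible shape:
13 | 
14 | .. code-block:: python
15 | 
16 |     from deep4cast import metrics
17 | 
18 |     # SMAPE per series; average for a single headline number
19 |     test_smape = metrics.smape(y_samples, y_test)
20 |     print('SMAPE: {}%'.format(test_smape.mean()))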
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ======
3 |
4 | .. automodule:: models
5 | :members:
6 |
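7 | Example
8 | -------
9 | 
10 | A minimal sketch of building the temporal convolutional model used in the example notebook. The hyperparameter values are illustrative, and it assumes ``WaveNet`` is importable from ``deep4cast.models``:
11 | 
12 | .. code-block:: python
13 | 
14 |     from deep4cast.models import WaveNet
15 | 
16 |     # Channel counts and depth are illustrative; tune them for your data
17 |     model = WaveNet(input_channels=1,
18 |                     output_channels=1,
19 |                     horizon=12,
20 |                     hidden_channels=89,
21 |                     skip_channels=199,
22 |                     n_layers=7)
23 | 
24 |     print('Number of model parameters: {}.'.format(model.n_parameters))
25 |     print('Receptive field size: {}.'.format(model.receptive_field_size))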
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx_rtd_theme
2 | nbsphinx
--------------------------------------------------------------------------------
/docs/transforms.rst:
--------------------------------------------------------------------------------
1 | Transformations
2 | ===============
3 |
4 | Transformations of the time series intended to be used in a similar fashion to
5 | `torchvision `_.
6 |
7 | .. automodule:: transforms
8 | :members:
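9 | 
10 | Example
11 | -------
12 | 
13 | A sketch of chaining transforms with ``Compose``, mirroring the example notebook; it assumes the ``transforms`` module is imported from the ``deep4cast`` package:
14 | 
15 | .. code-block:: python
16 | 
17 |     from deep4cast import transforms
18 | 
19 |     transform = transforms.Compose([
20 |         transforms.ToTensor(),                              # convert the example to a tensor
21 |         transforms.LogTransform(targets=[0], offset=1.0),   # log(offset + x) on the targets
22 |         transforms.RemoveLast(targets=[0]),                 # subtract the last lookback value
23 |         transforms.Target(targets=[0]),                     # index of the variable to forecast
24 |     ])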
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.0.0
2 | torchvision>=0.2.1
3 | matplotlib>=3.0.3
4 | numpy>=1.16.2
5 | pandas>=0.24.2
6 | scipy>=1.2.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='deep4cast',
5 | version='0.1a',
6 | description='package for forecasting using deep learning',
7 | url='https://github.com/MSRDL/Deep4Cast',
8 | author='Microsoft',
9 | author_email='ktran@microsoft.com',
10 | license='BSD',
11 | packages=['deep4cast'],
12 | install_requires=[],
13 | zip_safe=False
14 | )
15 |
--------------------------------------------------------------------------------