├── solar.sh
├── traffic.sh
├── electricity.sh
├── exchange_rate.sh
├── lstnet_datautil.py
├── main.py
├── README.md
├── lstnet_plot.py
├── lstnet_util.py
└── lstnet_model.py
/solar.sh:
--------------------------------------------------------------------------------
1 | python3.6 main.py --data="data/solar_AL.txt" --SkipGRUUnits=10 --save="save/solar" --test --savehistory --logfilename="log/lstnet" --debuglevel=20
2 |
--------------------------------------------------------------------------------
/traffic.sh:
--------------------------------------------------------------------------------
1 | python3.6 main.py --data="data/traffic.txt" --SkipGRUUnits=10 --save="save/traffic" --test --savehistory --logfilename="log/lstnet" --debuglevel=20
2 |
--------------------------------------------------------------------------------
/electricity.sh:
--------------------------------------------------------------------------------
1 | python3.6 main.py --data="data/electricity.txt" --horizon=24 --save="save/electricity" --test --savehistory --logfilename="log/lstnet" --debuglevel=20
2 |
--------------------------------------------------------------------------------
/exchange_rate.sh:
--------------------------------------------------------------------------------
1 | python3.6 main.py --data="data/exchange_rate.txt" --CNNFilters=50 --GRUUnits=50 --skip=12 --save="save/exchange_rate" --test --savehistory --logfilename="log/lstnet" --debuglevel=20
2 |
--------------------------------------------------------------------------------
/lstnet_datautil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Logging
4 | from __main__ import logger_name
5 | import logging
6 | log = logging.getLogger(logger_name)
7 |
8 | class DataUtil(object):
9 | #
10 | # This class contains data specific information.
11 | # It does the following:
12 | # - Read data from file
13 | # - Normalise it
14 | # - Split it into train, dev (validation) and test
15 | # - Create X and Y for each of the 3 sets (train, dev, test) according to the following:
16 | # Every sample (x, y) shall be created as follows:
17 | # - x --> window number of values
18 | # - y --> one value that lies 'horizon' time steps in the future, i.e. 'horizon' steps past the last value of x
19 | # This way X and Y will have the following dimensions:
20 | # - X [number of samples, window, number of multivariate time series]
21 | # - Y [number of samples, number of multivariate time series]
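#
# Worked example (matching get_data below; the concrete numbers are illustrative), with window = 168 and horizon = 12:
# for a target index t, y = data[t] and x = data[t-179 : t-11] (168 rows),
# so y sits 12 steps past the last row of x. The first usable target index is
# window + horizon - 1 = 179, which is why split_data starts the training range there.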
22 |
23 | def __init__(self, filename, train, valid, horizon, window, normalise = 2):
24 | try:
25 | log.debug("Start reading data")
26 | with open(filename) as fin:
27 | self.rawdata = np.loadtxt(fin, delimiter=',')
28 | log.debug("End reading data")
29 |
30 |
31 | self.w = window
32 | self.h = horizon
33 | self.data = np.zeros(self.rawdata.shape)
34 | self.n, self.m = self.data.shape
35 | self.normalise = normalise
36 | self.scale = np.ones(self.m)
37 |
38 | self.normalise_data(normalise)
39 | self.split_data(train, valid)
40 | except IOError as err:
41 | # In case file is not found, all of the above attributes will not have been created
42 | # Hence, in order to check if this call was successful, you can call hasattr on this object
43 | # to check if it has attribute 'data' for example
44 | log.error("Error opening data file ... %s", err)
45 |
46 |
47 | def normalise_data(self, normalise):
48 | log.debug("Normalise: %d", normalise)
49 |
50 | if normalise == 0: # do not normalise
51 | self.data = self.rawdata
52 |
53 | if normalise == 1: # same normalisation for all timeseries
54 | self.data = self.rawdata / np.max(self.rawdata)
55 |
56 | if normalise == 2: # normalise each timeseries alone. This is the default mode
57 | for i in range(self.m):
58 | self.scale[i] = np.max(np.abs(self.rawdata[:, i]))
59 | self.data[:, i] = self.rawdata[:, i] / self.scale[i]
60 |
61 | def split_data(self, train, valid):
62 | log.info("Splitting data into training set (%.2f), validation set (%.2f) and testing set (%.2f)", train, valid, 1 - (train + valid))
63 |
64 | train_set = range(self.w + self.h - 1, int(train * self.n))
65 | valid_set = range(int(train * self.n), int((train + valid) * self.n))
66 | test_set = range(int((train + valid) * self.n), self.n)
67 |
68 | self.train = self.get_data(train_set)
69 | self.valid = self.get_data(valid_set)
70 | self.test = self.get_data(test_set)
71 |
72 | def get_data(self, rng):
73 | n = len(rng)
74 |
75 | X = np.zeros((n, self.w, self.m))
76 | Y = np.zeros((n, self.m))
77 |
78 | for i in range(n):
79 | end = rng[i] - self.h + 1
80 | start = end - self.w
81 |
82 | X[i,:,:] = self.data[start:end, :]
83 | Y[i,:] = self.data[rng[i],:]
84 |
85 | return [X, Y]
86 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | ####################################################################################
2 | # Implementation of the following paper: https://arxiv.org/pdf/1703.07015.pdf #
3 | # #
4 | # Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks #
5 | ####################################################################################
6 |
7 | # This must be set at the beginning because other modules import it via 'from __main__ import logger_name'
8 | logger_name = "lstnet"
9 |
10 | # Path appended in order to import from util
11 | import sys
12 | sys.path.append('..')
13 | from util.model_util import LoadModel, SaveModel, SaveResults, SaveHistory
14 | from util.Msglog import LogInit
15 |
16 | from datetime import datetime
17 |
18 | from lstnet_util import GetArguments, LSTNetInit
19 | from lstnet_datautil import DataUtil
20 | from lstnet_model import PreSkipTrans, PostSkipTrans, PreARTrans, PostARTrans, LSTNetModel, ModelCompile
21 | from lstnet_plot import AutoCorrelationPlot, PlotHistory, PlotPrediction
22 |
23 | import tensorflow as tf
24 |
25 |
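# Custom layer classes, passed to LoadModel so that a model saved as JSON
# can be rebuilt: Keras needs these classes in order to deserialise the custom layers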
26 | custom_objects = {
27 | 'PreSkipTrans': PreSkipTrans,
28 | 'PostSkipTrans': PostSkipTrans,
29 | 'PreARTrans': PreARTrans,
30 | 'PostARTrans': PostARTrans
31 | }
32 |
33 | def train(model, data, init, tensorboard = None):
34 | if init.validate == True:
35 | val_data = (data.valid[0], data.valid[1])
36 | else:
37 | val_data = None
38 |
39 | start_time = datetime.now()
40 | history = model.fit(
41 | x = data.train[0],
42 | y = data.train[1],
43 | epochs = init.epochs,
44 | batch_size = init.batchsize,
45 | validation_data = val_data,
46 | callbacks = [tensorboard] if tensorboard else None
47 | )
48 | end_time = datetime.now()
49 | log.info("Training time took: %s", str(end_time - start_time))
50 |
51 | return history
52 |
53 |
54 | if __name__ == '__main__':
55 | try:
56 | args = GetArguments()
57 | except SystemExit as err:
58 | print("Error reading arguments")
59 | exit(0)
60 |
61 | test_result = None
62 |
63 | # Initialise parameters
64 | lstnet_init = LSTNetInit(args)
65 |
66 | # Initialise logging
67 | log = LogInit(logger_name, lstnet_init.logfilename, lstnet_init.debuglevel, lstnet_init.log)
68 | log.info("Python version: %s", sys.version)
69 | log.info("Tensorflow version: %s", tf.__version__)
70 | log.info("Keras version: %s ... Using tensorflow embedded keras", tf.keras.__version__)
71 |
72 | # Dumping configuration
73 | lstnet_init.dump()
74 |
75 | # Reading data
76 | Data = DataUtil(lstnet_init.data,
77 | lstnet_init.trainpercent,
78 | lstnet_init.validpercent,
79 | lstnet_init.horizon,
80 | lstnet_init.window,
81 | lstnet_init.normalise)
82 |
83 | # If file does not exist, then Data will not have attribute 'data'
84 | if not hasattr(Data, 'data'):
85 | log.critical("Could not load data!! Exiting")
86 | exit(1)
87 |
88 | log.info("Training shape: X:%s Y:%s", str(Data.train[0].shape), str(Data.train[1].shape))
89 | log.info("Validation shape: X:%s Y:%s", str(Data.valid[0].shape), str(Data.valid[1].shape))
90 | log.info("Testing shape: X:%s Y:%s", str(Data.test[0].shape), str(Data.test[1].shape))
91 |
92 | if lstnet_init.plot == True and lstnet_init.autocorrelation is not None:
93 | AutoCorrelationPlot(Data, lstnet_init)
94 |
95 | # If --load is set, load model from file, otherwise create model
96 | if lstnet_init.load is not None:
97 | log.info("Load model from %s", lstnet_init.load)
98 | lstnet = LoadModel(lstnet_init.load, custom_objects)
99 | else:
100 | log.info("Creating model")
101 | lstnet = LSTNetModel(lstnet_init, Data.train[0].shape)
102 |
103 | if lstnet is None:
104 | log.critical("Model could not be loaded or created ... exiting!!")
105 | exit(1)
106 |
107 | # Compile model
108 | lstnet_tensorboard = ModelCompile(lstnet, lstnet_init)
109 | if lstnet_tensorboard is not None:
110 | log.info("Model compiled ... Open tensorboard in order to visualise it!")
111 | else:
112 | log.info("Model compiled ... No tensorboard visualisation is available")
113 |
114 | # Model Training
115 | if lstnet_init.train is True:
116 | # Train the model
117 | log.info("Training model ... ")
118 | h = train(lstnet, Data, lstnet_init, lstnet_tensorboard)
119 |
120 | # Plot training metrics
121 | if lstnet_init.plot is True:
122 | PlotHistory(h.history, ['loss', 'rse', 'corr'], lstnet_init)
123 |
124 | # Saving model if lstnet_init.save is not None.
125 | # There's no reason to save a model if lstnet_init.train == False
126 | SaveModel(lstnet, lstnet_init.save)
127 | if lstnet_init.saveresults == True:
128 | SaveResults(lstnet, lstnet_init, h.history, test_result, ['loss', 'rse', 'corr'])
129 | if lstnet_init.savehistory == True:
130 | SaveHistory(lstnet_init.save, h.history)
131 |
132 | # Validation
133 | if lstnet_init.train is False and lstnet_init.validate is True:
134 | loss, rse, corr = lstnet.evaluate(Data.valid[0], Data.valid[1])
135 | log.info("Validation on the validation set returned: Loss:%f, RSE:%f, Correlation:%f", loss, rse, corr)
136 | elif lstnet_init.validate == True:
137 | log.info("Validation on the validation set returned: Loss:%f, RSE:%f, Correlation:%f",
138 | h.history['val_loss'][-1], h.history['val_rse'][-1], h.history['val_corr'][-1])
139 |
140 | # Testing evaluation
141 | if lstnet_init.evaltest is True:
142 | loss, rse, corr = lstnet.evaluate(Data.test[0], Data.test[1])
143 | log.info("Validation on the test set returned: Loss:%f, RSE:%f, Correlation:%f", loss, rse, corr)
144 | test_result = {'loss': loss, 'rse': rse, 'corr': corr}
145 |
146 | # Prediction
147 | if lstnet_init.predict is not None:
148 | if lstnet_init.predict == 'trainingdata' or lstnet_init.predict == 'all':
149 | log.info("Predict training data")
150 | trainPredict = lstnet.predict(Data.train[0])
151 | else:
152 | trainPredict = None
153 | if lstnet_init.predict == 'validationdata' or lstnet_init.predict == 'all':
154 | log.info("Predict validation data")
155 | validPredict = lstnet.predict(Data.valid[0])
156 | else:
157 | validPredict = None
158 | if lstnet_init.predict == 'testingdata' or lstnet_init.predict == 'all':
159 | log.info("Predict testing data")
160 | testPredict = lstnet.predict(Data.test[0])
161 | else:
162 | testPredict = None
163 |
164 | if lstnet_init.plot is True:
165 | PlotPrediction(Data, lstnet_init, trainPredict, validPredict, testPredict)
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LSTNet
2 | This repository is a Tensorflow / Keras implementation of the paper __*Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks*__ (https://arxiv.org/pdf/1703.07015.pdf)
3 |
4 | This implementation was inspired by the following PyTorch implementation: https://github.com/laiguokun/LSTNet
5 |
6 | ## Installation
7 | Clone the prerequisite repository:
8 | ```shell
9 | git clone https://github.com/fbadine/util.git
10 | ```
11 |
12 | Clone this repository:
13 | ```shell
14 | git clone https://github.com/fbadine/LSTNet.git
15 | cd LSTNet
16 | mkdir log/ save/ data/
17 | ```
18 |
19 | Download the dataset from https://github.com/laiguokun/multivariate-time-series-data and copy the text files into LSTNet/data/
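
For example, assuming the layout of that repository (the text files are stored gzipped, one folder per dataset; adjust the paths if the layout differs):
```shell
git clone https://github.com/laiguokun/multivariate-time-series-data.git
gunzip -c multivariate-time-series-data/electricity/electricity.txt.gz > LSTNet/data/electricity.txt
```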
20 |
21 | ## Usage
22 | ### Training
23 | There are four sample scripts that train, validate and test the model on the different datasets; a fully spelled-out equivalent of the traffic script follows the list:
24 | - electricity.sh
25 | - exchange_rate.sh
26 | - solar.sh
27 | - traffic.sh
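
Each script mostly relies on the defaults defined in `lstnet_util.py`. Spelled out in full, the traffic run is equivalent to the following (every extra flag below is just a default made explicit):
```shell
python3.6 main.py --data="data/traffic.txt" --window=168 --horizon=12 --CNNFilters=100 --CNNKernel=6 --GRUUnits=100 --SkipGRUUnits=10 --skip=24 --dropout=0.2 --normalize=2 --highway=24 --lr=0.001 --batchsize=128 --epochs=100 --save="save/traffic" --test --savehistory --logfilename="log/lstnet" --debuglevel=20
```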
28 |
29 | ### Predict
30 | In order to predict and plot a timeseries, run `main.py` as follows (example for the electricity dataset):
31 | ```shell
32 | python main.py --data="data/electricity.txt" --no-train --load="save/electricity/electricity" --predict=all --plot --series-to-plot=0
33 | ```
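
The input data can also be inspected without training at all. For example, to plot the autocorrelation of 3 random series over timeslots 0 to 1000 (flag format as described under `--autocorrelation` in the table below):
```shell
python main.py --data="data/electricity.txt" --no-train --no-validation --autocorrelation=3,0,1000 --plot
```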
34 |
35 | ### Running Options
36 | The following are the parameters that the python script takes along with their description:
37 |
38 | | Input Parameters | Default | Description |
39 | | :-----------------| :------------------| :-----------|
40 | | --data | |Full Path of the data file. __(REQUIRED)__|
41 | | --normalize |2 |Type of data normalisation:<br>- 0: No Normalisation<br>- 1: Normalise all timeseries together<br>- 2: Normalise each timeseries alone|
42 | | --trainpercent |0.6 |Percentage of the given data to use for training|
43 | | --validpercent |0.2 |Percentage of the given data to use for validation|
44 | | --window |24 * 7 |Number of time values to consider in each input X|
45 | | --horizon |12 |How far is the predicted value Y. It is horizon values away from the last value of X (into the future)|
46 | | --CNNFilters |100 |Number of output filters in the CNN layer<br>A value of 0 will remove this layer|
47 | | --CNNKernel |6 |CNN filter size that will be (CNNKernel, number of multivariate timeseries)<br>A value of 0 will remove this layer|
48 | | --GRUUnits |100 |Number of hidden states in the GRU layer|
49 | | --SkipGRUUnits |5 |Number of hidden states in the SkipGRU layer|
50 | | --skip |24 |Number of timeslots to skip.<br>A value of 0 will remove this layer|
51 | | --dropout |0.2 |Dropout frequency|
52 | | --highway |24 |Number of timeslot values to consider for the linear layer (AR layer)|
53 | | --initializer |glorot_uniform |The weights initialiser to use|
54 | | --loss |mean_absolute_error |The loss function to use for optimisation|
55 | | --optimizer |Adam |The optimiser to use.<br>Accepted values:<br>- SGD<br>- RMSprop<br>- Adam|
56 | | --lr |0.001 |Learning rate|
57 | | --batchsize |128 |Training batchsize|
58 | | --epochs |100 |Number of training epochs|
59 | | --tensorboard |None |Set to the folder where to put the tensorboard file.<br>If set to None => no tensorboard|
60 | | --no-train | |Do not train the model|
61 | | --no-validation | |Do not validate the model|
62 | | --test | |Evaluate the model on the test data|
63 | | --load |None |Location and name of the file to load a pre-trained model from, as follows:<br>- Model in filename.json<br>- Weights in filename.h5|
64 | | --save |None |Full path of the file to save the model in, as follows:<br>- Model in filename.json<br>- Weights in filename.h5<br>This location is also used to save results and history as follows:<br>- Results in filename.txt<br>- History in filename_history.csv if --savehistory is passed|
65 | | --no-saveresults | |Do not save results|
66 | | --savehistory | |Save training / validation history in file as described in parameter --save above|
67 | | --predict |None |Predict timeseries using the trained model.<br>It takes one of the following values:<br>- trainingdata: predict the training data only<br>- validationdata: predict the validation data only<br>- testingdata: predict the testing data only<br>- all: all of the above<br>- None: none of the above|
68 | | --plot | |Generate plots|
69 | | --series-to-plot |0 |Series to plot.<br>Format: series,start,end<br>- series: the number of the series you wish to plot<br>- start: start timeslot (default is the start of the timeseries)<br>- end: end timeslot (default is the end of the timeseries)|
70 | | --autocorrelation |None |Autocorrelation plotting.<br>Format: series,start,end<br>- series: the number of random timeseries you wish to plot the autocorrelation for<br>- start: start timeslot (default is the start of the timeseries)<br>- end: end timeslot (default is the end of the timeseries)|
71 | | --save-plot |None |Location and name of the file to save the plotted images to:<br>- Autocorrelation in filename_autocorrelation.png<br>- Training history in filename_training.png<br>- Prediction in filename_prediction.png|
72 | | --no-log | |Do not create logfiles.<br>However, error and critical messages will still appear|
73 | | --logfilename |log/lstnet |Full path of the logging file|
74 | | --debuglevel |20 |Logging debug level|
75 |
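The parameters can also be supplied programmatically: `LSTNetInit` accepts a dictionary when `args_is_dictionary=True`. Below is a minimal sketch, assuming it is run from the repository root; note that every key must be present, since the class applies no defaults of its own:
```python
# The lstnet_* data and plot modules do `from __main__ import logger_name`,
# so it must be defined in the executing script before they are imported.
logger_name = "lstnet"

from lstnet_util import LSTNetInit
from lstnet_datautil import DataUtil
from lstnet_model import LSTNetModel, ModelCompile

params = {
    "data": "data/electricity.txt", "window": 24 * 7, "horizon": 24,
    "CNNFilters": 100, "CNNKernel": 6, "GRUUnits": 100, "SkipGRUUnits": 5,
    "skip": 24, "dropout": 0.2, "normalize": 2, "highway": 24,
    "batchsize": 128, "epochs": 100, "initializer": "glorot_uniform",
    "trainpercent": 0.6, "validpercent": 0.2, "no_train": False,
    "no_validation": False, "save": "save/electricity", "no_saveresults": False,
    "savehistory": True, "load": None, "loss": "mean_absolute_error",
    "lr": 0.001, "optimizer": "Adam", "test": True, "tensorboard": None,
    "plot": False, "predict": None, "series_to_plot": "0",
    "autocorrelation": None, "save_plot": None, "no_log": False,
    "debuglevel": 20, "logfilename": "log/lstnet",
}
init = LSTNetInit(params, args_is_dictionary=True)

# Same steps main.py performs: read/normalise/split the data, build and compile
# (check hasattr(data, 'data') to confirm the file was actually read)
data = DataUtil(init.data, init.trainpercent, init.validpercent,
                init.horizon, init.window, init.normalise)
model = LSTNetModel(init, data.train[0].shape)
tensorboard = ModelCompile(model, init)
```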
76 |
77 | ## Results
78 | The following are the results that were obtained:
79 |
80 | | Dataset | Width | Horizon | Correlation | RSE |
81 | | :-------------| :-----------| :-----------| :-----------| :-----------|
82 | | Solar | 28 hours | 2 hours | 0.9548 | 0.3060 |
83 | | Traffic | 7 days | 12 hours | 0.8932 | 0.4089 |
84 | | Electricity | 7 days | 24 hours | 0.8856 | 0.3746 |
85 | | Exchange Rate | 168 days | 12 days | 0.9731 | 0.1540 |
86 |
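The RSE and Correlation columns are the metrics implemented by `rse` and `corr` in `lstnet_model.py`:

$$\mathrm{RSE} = \frac{\sqrt{\mathrm{mean}\big((Y - \hat{Y})^2\big)}}{\mathrm{std}(Y)} \qquad \mathrm{CORR} = \frac{1}{m} \sum_{k=1}^{m} \frac{\mathrm{mean}_i\big[(Y_{ik} - \bar{Y}_k)(\hat{Y}_{ik} - \bar{\hat{Y}}_k)\big]}{\mathrm{std}_i(Y_{ik}) \, \mathrm{std}_i(\hat{Y}_{ik})}$$

where $i$ indexes the samples and $k$ the $m$ time series.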
87 | ## Dataset
88 | As described in the paper, the data is composed of 4 publicly available datasets downloadable from https://github.com/laiguokun/multivariate-time-series-data:
89 | - __Traffic:__ A collection of 48 months (2015-2016) hourly data from the California Department of Transportation
90 | - __Solar Energy:__ The solar power production records in 2006, sampled every 10 minutes from 137 PV plants in the state of Alabama
91 | - __Electricity:__ Electricity consumption for 321 clients recorded every 15 minutes from 2012 to 2014
92 | - __Exchange Rate:__ A collection of daily average rates of 8 currencies from 1990 to 2016
93 |
94 | ## Environment
95 | ### Primary environment
96 | The results were obtained on a system with the following versions:
97 | - Python 3.6.8
98 | - Tensorflow 1.11.0
99 | - Keras 2.1.6-tf
100 |
101 | ### TensorFlow 2.0 Ready
102 | The model has also been tested on TF 2.0 alpha version:
103 | - Python 3.6.7
104 | - Tensorflow 2.0.0-alpha0
105 | - Keras 2.2.4-tf
106 |
--------------------------------------------------------------------------------
/lstnet_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from pandas.plotting import autocorrelation_plot
5 |
6 | # logging
7 | from __main__ import logger_name
8 | import logging
9 | log = logging.getLogger(logger_name)
10 |
11 |
12 | def AutoCorrelationPlot(Data, init):
13 | if Data is not None and init is not None:
14 | log.info("Plotting autocorrelation ...")
15 | #
16 | # init.autocorrelation has the following format: number_of_series,start,end
17 | # which means that we will be plotting an autocorrelation for number_of_series random series from start to end
18 | #
19 | # Here we are transforming this string into a list of integers where possible
20 | #
21 | s = [int(i) if i.isdigit() else i for i in init.autocorrelation.split(',')]
22 |
23 | #
24 | # Check if the first element in the list is an integer
25 | # and is between 1 and the number of timeseries otherwise set it to the number of timeseries i.e. plot all
26 | #
27 | try:
28 | assert(s[0] and type(s[0]) == int and s[0] > 0 and s[0] <= Data.m)
29 | number = s[0]
30 | except AssertionError as err:
31 | log.warning("The number of series to plot autocorrelation for must be in the range [1,%d]. Setting it to %d", Data.m, Data.m)
32 | number = Data.m
33 |
34 | #
35 | # Check if the second element in the list exists (len(s)>1) and is an integer
36 | # and is less than the length of the timeseries otherwise set it to 0 (start of the timeseries)
37 | #
38 | try:
39 | assert(len(s) > 1 and s[1] and type(s[1]) == int and s[1] < Data.n)
40 | start_plot = s[1]
41 | except AssertionError as err:
42 | log.warning("start must be an integer less than %d. Setting it to 0", Data.n)
43 | start_plot = 0
44 |
45 | #
46 | # Check if the third element in the list exists (len(s)>2) and is an integer and is bigger than the start_plot
47 | # and is less than the length of the timeseries otherwise set it to end of the timeseries
48 | #
49 | try:
50 | assert(len(s) > 2 and s[2] and type(s[2]) == int and s[2] > start_plot and s[2] < Data.n)
51 | end_plot = s[2]
52 | except AssertionError as err:
53 | log.warning("end must be an integer in the range ]%d,%d[. Setting it to %d", start_plot, Data.n, Data.n - 1)
54 | end_plot = Data.n - 1
55 |
56 | fig = plt.figure()
57 |
58 | log.debug("Plotting autocorrelation for %d random timeseries out of %d. Timeslot from %d to %d", number, Data.m, start_plot, end_plot)
59 | series = np.random.choice(range(Data.m), number, replace=False)
60 | for i in series:
61 | autocorrelation_plot(Data.data[start_plot:end_plot,i])
62 |
63 | fig.canvas.set_window_title('Auto Correlation')
64 | plt.show()
65 |
66 | if init.save_plot is not None:
67 | log.debug("Saving autocorrelation plot to: %s", init.save_plot + "_autocorrelation.png")
68 | fig.savefig(init.save_plot + "_autocorrelation.png")
69 |
70 |
71 | def PlotHistory(history, metrics, init):
72 | if history is not None and metrics is not None and init is not None:
73 | log.info("Plotting history ...")
74 |
75 | # Subplot index, incremented after each metric is plotted
76 | i = 1
77 |
78 | #
79 | # Number of metrics that were trained and are available in history
80 | # This will help us determine the width of the canvas as well as correctly set
81 | # the parameters to subplot
82 | #
83 | n = len(history)
84 |
85 | #
86 | # The number of rows is set so that the training results are on one line and the
87 | # validation ones are on the second. Therefore:
88 | # number of available metrics in history
89 | # rows = --------------------------------------- = 2 in case of validate=True. Otherwise 1
90 | # number of metrics
91 | #
92 | rows = int(n / len(metrics))
93 |
94 | #
95 | # Number of columns i.e. number of different metrics plotted for each of the training and validation
96 | #
97 | cols = int(n / rows)
98 |
99 | #
100 | # Set the plotting image size
101 | # If the number of columns is greater than 2, choose 16, otherwise 12
102 | # If the number of rows is greater than 1, choose 10, otherwise 5
103 | #
104 | fig = plt.figure(figsize=(16 if cols > 2 else 12, 10 if rows > 1 else 5))
105 |
106 | # Training data history plot
107 | for m in metrics:
108 | key = m
109 | log.debug("Plotting metrics %s", key)
110 | plt.subplot(rows, cols, i)
111 | plt.plot(history[key])
112 | plt.ylabel(m.title())
113 | plt.xlabel("Epochs")
114 | plt.title("Training " + m.title())
115 | i = i + 1
116 |
117 | # Validation data history plot (if available)
118 | for m in metrics:
119 | key = "val_" + m
120 | log.debug("Plotting metrics %s", key)
121 | # if key is not in history.keys() => --validate was set to False and therefore history for validation is not available
122 | if key in history.keys():
123 | plt.subplot(rows, cols, i)
124 | plt.plot(history[key])
125 | plt.ylabel(m.title())
126 | plt.xlabel("Epochs")
127 | plt.title("Validation " + m.title())
128 | i = i + 1
129 |
130 | fig.canvas.set_window_title('Training History')
131 | plt.show()
132 |
133 | if init.save_plot is not None:
134 | log.debug("Saving training history plot to: %s", init.save_plot + "_training.png")
135 | fig.savefig(init.save_plot + "_training.png")
136 |
137 | def PlotPrediction(Data, init, trainPredict, validPredict, testPredict):
138 | if Data is not None and init is not None:
139 | log.info("Plotting Prediction ...")
140 | #
141 | # init.series_to_plot has the following format: series_number,start,end
142 | # which means that we will be plotting series # series_number from start to end
143 | #
144 | # Here we are transforming this string into a list of integers where possible
145 | #
146 | s = [int(i) if i.isdigit() else i for i in init.series_to_plot.split(',')]
147 |
148 | #
149 | # Check if the first element in the list is an integer
150 | # and is less than the number of timeseries otherwise set it to 0
151 | #
152 | try:
153 | assert(s[0] and type(s[0]) == int and s[0] < Data.m)
154 | series = s[0]
155 | except AssertionError as err:
156 | log.warning("The series to plot must be an integer in the range [0,%d[. Setting it to 0", Data.m)
157 | series = 0
158 |
159 | #
160 | # Check if the second element in the list exists (len(s)>1) and is an integer
161 | # and is less than the length of the timeseries otherwise set it to 0 (start of the timeseries)
162 | #
163 | try:
164 | assert(len(s) > 1 and s[1] and type(s[1]) == int and s[1] < Data.n)
165 | start_plot = s[1]
166 | except AssertionError as err:
167 | log.warning("start must be an integer less than %d. Setting it to 0", Data.n)
168 | start_plot = 0
169 |
170 | #
171 | # Check if the third element in the list exists (len(s)>2) and is an integer and is bigger than the start_plot
172 | # and is less than the length of the timeseries otherwise set it to end of the timeseries
173 | #
174 | try:
175 | assert(len(s) > 2 and s[2] and type(s[2]) == int and s[2] > start_plot and s[2] < Data.n)
176 | end_plot = s[2]
177 | except AssertionError as err:
178 | log.warning("end must be an integer in the range ]%d,%d[. Setting it to %d", start_plot, Data.n, Data.n - 1)
179 | end_plot = Data.n - 1
180 |
181 |
182 | #
183 | # Create empty series of the same length as the data and set all values to nan.
184 | # This way, we can fill the appropriate section for train, valid and test so that
185 | # when we plot them, they appear at the appropriate location with respect to the original timeseries
186 | #
187 | log.debug("Initialising trainPredictPlot, ValidPredictPlot, testPredictPlot")
188 | trainPredictPlot = np.empty((Data.n, Data.m))
189 | trainPredictPlot[:,:] = np.nan
190 | validPredictPlot = np.empty((Data.n, Data.m))
191 | validPredictPlot[:,:] = np.nan
192 | testPredictPlot = np.empty((Data.n, Data.m))
193 | testPredictPlot[:,:] = np.nan
194 |
195 | #
196 | # We use a window of data to predict the value 'horizon' steps past the end of the window,
197 | # therefore the first prediction lands at index window + horizon - 1
198 | #
199 | start = init.window + init.horizon - 1
200 | end = start + len(Data.train[0]) # Same length as trainPredict however we might not have trainPredict
201 | if trainPredict is not None:
202 | log.debug("Filling trainPredictPlot from %d to %d", start, end)
203 | trainPredictPlot[start:end, :] = trainPredict
204 |
205 | start = end
206 | end = start + len(Data.valid[0]) # Same length as validPredict however we might not have validPredict
207 | if validPredict is not None:
208 | log.debug("Filling validPredictPlot from %d to %d", start, end)
209 | validPredictPlot[start:end, :] = validPredict
210 |
211 | start = end
212 | end = start + len(Data.test[0]) # Same length as testPredict however we might not have testPredict
213 | if testPredict is not None:
214 | log.debug("Filling testPredictPlot from %d to %d", start, end)
215 | testPredictPlot[start:end, :] = testPredict
216 |
217 | # Plotting the original series and whatever is available of trainPredictPlot, validPredictPlot and testPredictPlot
218 | fig = plt.figure()
219 |
220 | plt.plot(Data.data[start_plot:end_plot, series])
221 | plt.plot(trainPredictPlot[start_plot:end_plot, series])
222 | plt.plot(validPredictPlot[start_plot:end_plot, series])
223 | plt.plot(testPredictPlot[start_plot:end_plot, series])
224 |
225 | plt.ylabel("Timeseries")
226 | plt.xlabel("Time")
227 | plt.title("Prediction Plotting for timeseries # %d" % (series))
228 |
229 | fig.canvas.set_window_title('Prediction')
230 |
231 | plt.show()
232 |
233 | if init.save_plot is not None:
234 | log.debug("Saving prediction plot to: %s", init.save_plot + "_prediction.png")
235 | fig.savefig(init.save_plot + "_prediction.png")
236 |
--------------------------------------------------------------------------------
/lstnet_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | class LSTNetInit(object):
4 | #
5 | # This class contains all the initialisation information that is passed as arguments.
6 | #
7 | # data: Location of the data file
8 | # window: Number of time values to consider in each input X
9 | # Default: 24*7
10 | # horizon: How far is the predicted value Y. It is horizon values away from the last value of X (into the future)
11 | # Default: 12
12 | # CNNFilters: Number of output filters in the CNN layer
13 | # Default: 100
14 | # If set to 0, the CNN layer will be omitted
15 | # CNNKernel: CNN filter size that will be (CNNKernel, number of multivariate timeseries)
16 | # Default: 6
17 | # If set to 0, the CNN layer will be omitted
18 | # GRUUnits: Number of hidden states in the GRU layer
19 | # Default: 100
20 | # SkipGRUUnits: Number of hidden states in the SkipGRU layer
21 | # Default: 5
22 | # skip: Number of timeslots to skip
23 | # Default: 24
24 | # If set to 0, the SkipGRU layer will be omitted
25 | # dropout: Dropout frequency
26 | # Default: 0.2
27 | # normalise: Type of normalisation:
28 | # - 0: No normalisation
29 | # - 1: Normalise all timeseries together
30 | # - 2: Normalise each timeseries alone
31 | # Default: 2
32 | # batchsize: Training batch size
33 | # Default: 128
34 | # epochs: Number of training epochs
35 | # Default: 100
36 | # initialiser: The weights initialiser to use.
37 | # Default: glorot_uniform
38 | # trainpercent: How much percentage of the given data to use for training.
39 | # Default: 0.6 (60%)
40 | # validpercent: How much percentage of the given data to use for validation.
41 | # Default: 0.2 (20%)
42 | # The remaining (1 - trainpercent - validpercent) shall be the amount of test data
43 | # highway: Number of timeseries values to consider for the linear layer (AR layer)
44 | # Default: 24
45 | # If set to 0, the AR layer will be omitted
46 | # train: Whether to train the model or not
47 | # Default: True
48 | # validate: Whether to validate the model against the validation data
49 | # Default: True
50 | # If set and train is set, validation will be done while training.
51 | # evaltest: Evaluate the model using testing data
52 | # Default: False
53 | # save: Location and Name of the file to save the model in as follows:
54 | # Model in "save.json"
55 | # Weights in "save.h5"
56 | # Default: None
57 | # This location is also used to save results and history in, as follows:
58 | # Results in "save.txt" if --saveresults is passed
59 | # History in "save_history.csv" if --savehistory is passed
60 | # saveresults: Save results as described in 'save' above.
61 | # This has no effect if --save is not set
62 | # Default: True
63 | # savehistory: Save training / validation history as described in 'save' above.
64 | # This has no effect if --save is not set
65 | # Default: False
66 | # load: Location and Name of the file to load a pretrained model from as follows:
67 | # Model in "load.json"
68 | # Weights in "load.h5"
69 | # Default: None
70 | # loss: The loss function to use for optimisation.
71 | # Default: mean_absolute_error
72 | # lr: Learning rate
73 | # Default: 0.001
74 | # optimizer: The optimiser to use
75 | # Default: Adam
76 | # test: Evaluate the model on the test data
77 | # Default: False
78 | # tensorboard: Set to the folder where to put tensorboard file
79 | # Default: None (no tensorboard callback)
80 | # predict: Predict timeseries using the trained model.
81 | # It takes one of the following values:
82 | # - trainingdata => predict the training data only
83 | # - validationdata => predict the validation data only
84 | # - testingdata => predict the testing data only
85 | # - all => all of the above
86 | # - None => none of the above
87 | # Default: None
88 | # plot: Generate plots
89 | # Default: False
90 | # series_to_plot: The number of the series that you wish to plot. The value must be less than the number of series available
91 | # Default: 0
92 | # autocorrelation: The number of random series whose autocorrelation you wish to plot. The value must be less than or equal to the number of series available
93 | # Default: None
94 | # save_plot: Location and Name of the file to save the plotted images to as follows:
95 | # Autocorrelation in "save_plot_autocorrelation.png"
96 | # Training results in "save_plot_training.png"
97 | # Prediction in "save_plot_prediction.png"
98 | # Default: None
99 | # log: Whether to generate logging
100 | # Default: True
101 | # debuglevel: Logging debuglevel.
102 | # It takes one of the following values:
103 | # - 10 => DEBUG
104 | # - 20 => INFO
105 | # - 30 => WARNING
106 | # - 40 => ERROR
107 | # - 50 => CRITICAL
108 | # Default: 20
109 | # logfilename: Filename where logging will be written.
110 | # Default: log/lstnet
111 | #
112 | def __init__(self, args, args_is_dictionary = False):
113 | if args_is_dictionary is True:
114 | self.data = args["data"]
115 | self.window = args["window"]
116 | self.horizon = args["horizon"]
117 | self.CNNFilters = args["CNNFilters"]
118 | self.CNNKernel = args["CNNKernel"]
119 | self.GRUUnits = args["GRUUnits"]
120 | self.SkipGRUUnits = args["SkipGRUUnits"]
121 | self.skip = args["skip"]
122 | self.dropout = args["dropout"]
123 | self.normalise = args["normalize"]
124 | self.highway = args["highway"]
125 | self.batchsize = args["batchsize"]
126 | self.epochs = args["epochs"]
127 | self.initialiser = args["initializer"]
128 | self.trainpercent = args["trainpercent"]
129 | self.validpercent = args["validpercent"]
131 | self.train = not args["no_train"]
132 | self.validate = not args["no_validation"]
133 | self.save = args["save"]
134 | self.saveresults = not args["no_saveresults"]
135 | self.savehistory = args["savehistory"]
136 | self.load = args["load"]
137 | self.loss = args["loss"]
138 | self.lr = args["lr"]
139 | self.optimiser = args["optimizer"]
140 | self.evaltest = args["test"]
141 | self.tensorboard = args["tensorboard"]
142 | self.plot = args["plot"]
143 | self.predict = args["predict"]
144 | self.series_to_plot = args["series_to_plot"]
145 | self.autocorrelation = args["autocorrelation"]
146 | self.save_plot = args["save_plot"]
147 | self.log = not args["no_log"]
148 | self.debuglevel = args["debuglevel"]
149 | self.logfilename = args["logfilename"]
150 | else:
151 | self.data = args.data
152 | self.window = args.window
153 | self.horizon = args.horizon
154 | self.CNNFilters = args.CNNFilters
155 | self.CNNKernel = args.CNNKernel
156 | self.GRUUnits = args.GRUUnits
157 | self.SkipGRUUnits = args.SkipGRUUnits
158 | self.skip = args.skip
159 | self.dropout = args.dropout
160 | self.normalise = args.normalize
161 | self.highway = args.highway
162 | self.batchsize = args.batchsize
163 | self.epochs = args.epochs
164 | self.initialiser = args.initializer
165 | self.trainpercent = args.trainpercent
166 | self.validpercent = args.validpercent
168 | self.train = not args.no_train
169 | self.validate = not args.no_validation
170 | self.save = args.save
171 | self.saveresults = not args.no_saveresults
172 | self.savehistory = args.savehistory
173 | self.load = args.load
174 | self.loss = args.loss
175 | self.lr = args.lr
176 | self.optimiser = args.optimizer
177 | self.evaltest = args.test
178 | self.tensorboard = args.tensorboard
179 | self.plot = args.plot
180 | self.predict = args.predict
181 | self.series_to_plot = args.series_to_plot
182 | self.autocorrelation = args.autocorrelation
183 | self.save_plot = args.save_plot
184 | self.log = not args.no_log
185 | self.debuglevel = args.debuglevel
186 | self.logfilename = args.logfilename
187 |
188 | def dump(self):
189 | from __main__ import logger_name
190 | import logging
191 | log = logging.getLogger(logger_name)
192 |
193 | log.debug("Data: %s", self.data)
194 | log.debug("Window: %d", self.window)
195 | log.debug("Horizon: %d", self.horizon)
196 | log.debug("CNN Filters: %d", self.CNNFilters)
197 | log.debug("CNN Kernel: %d", self.CNNKernel)
198 | log.debug("GRU Units: %d", self.GRUUnits)
199 | log.debug("Skip GRU Units: %d", self.SkipGRUUnits)
200 | log.debug("Skip: %d", self.skip)
201 | log.debug("Dropout: %f", self.dropout)
202 | log.debug("Normalise: %d", self.normalise)
203 | log.debug("Highway: %d", self.highway)
204 | log.debug("Batch size: %d", self.batchsize)
205 | log.debug("Epochs: %d", self.epochs)
206 | log.debug("Learning rate: %s", str(self.lr))
207 | log.debug("Initialiser: %s", self.initialiser)
208 | log.debug("Optimiser: %s", self.optimiser)
209 | log.debug("Loss function to use: %s", self.loss)
210 | log.debug("Fraction of data to be used for training: %.2f", self.trainpercent)
211 | log.debug("Fraction of data to be used for validation: %.2f", self.validpercent)
212 | log.debug("Train model: %s", self.train)
213 | log.debug("Validate model: %s", self.validate)
214 | log.debug("Test model: %s", self.evaltest)
215 | log.debug("Save model location: %s", self.save)
216 | log.debug("Save Results: %s", self.saveresults)
217 | log.debug("Save History: %s", self.savehistory)
218 | log.debug("Load Model from: %s", self.load)
219 | log.debug("TensorBoard: %s", self.tensorboard)
220 | log.debug("Plot: %s", self.plot)
221 | log.debug("Predict: %s", self.predict)
222 | log.debug("Series to plot: %s", self.series_to_plot)
223 | log.debug("Save plot: %s", self.save_plot)
224 | log.debug("Create log: %s", self.log)
225 | log.debug("Debug level: %d", self.debuglevel)
226 | log.debug("Logfile: %s", self.logfilename)
227 |
228 |
229 | def GetArguments():
230 | # Creating the argument parser
231 | parser = argparse.ArgumentParser(description='LSTNet Model')
232 |
233 | parser.add_argument('--data', type=str, required=True, help='Location of the data file. Required!!')
234 | parser.add_argument('--window', type=int, default=24*7, help='Window size. Default=24*7')
235 | parser.add_argument('--horizon', type=int, default=12, help='Horizon width. Default=12')
236 | parser.add_argument('--CNNFilters', type=int, default=100, help='Number of CNN layer filters. Default=100. If set to 0, the CNN layer will be omitted')
237 | parser.add_argument('--CNNKernel', type=int, default=6, help='Size of the CNN filters. Default=6. If set to 0, the CNN layer will be omitted')
238 | parser.add_argument('--GRUUnits', type=int, default=100, help='Number of GRU hidden units. Default=100')
239 | parser.add_argument('--SkipGRUUnits', type=int, default=5, help='Number of hidden units in the Skip GRU layer. Default=5')
240 | parser.add_argument('--skip', type=int, default=24,
241 | help='Size of the window to skip in the Skip GRU layer. Default=24. If set to 0, the SkipGRU layer will be omitted')
242 | parser.add_argument('--dropout', type=float, default=0.2, help='Dropout to be applied to layers. 0 means no dropout. Default=0.2')
243 | parser.add_argument('--normalize', type=int, default=2,
244 | help='0 = do not normalise, 1 = use same normalisation for all timeseries, 2 = normalise each timeseries independently. Default=2')
245 | parser.add_argument('--highway', type=int, default=24, help='The window size of the highway component. Default=24. If set to 0, the AR layer will be omitted')
246 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate. Default=0.001')
247 | parser.add_argument('--batchsize', type=int, default=128, help='Training batchsize. Default=128')
248 | parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to run for training. Default=100')
249 | parser.add_argument('--initializer', type=str, default="glorot_uniform", help='Weights initialiser to use. Default=glorot_uniform')
250 | parser.add_argument('--loss', type=str, default="mean_absolute_error", help='Loss function to use. Default=mean_absolute_error')
251 | parser.add_argument('--optimizer', type=str, default="Adam", help='Optimisation function to use. Default=Adam')
252 | parser.add_argument('--trainpercent', type=float, default=0.6, help='Percentage of data to be used for training. Default=0.6')
253 | parser.add_argument('--validpercent', type=float, default=0.2, help='Percentage of data to be used for validation. Default=0.2')
254 | parser.add_argument('--save', type=str, default=None, help='Filename initial to save the model and the results in. Default=None')
255 | parser.add_argument('--load', type=str, default=None, help='Filename initial of the saved model to be loaded (model and weights). Default=None')
256 | parser.add_argument('--tensorboard', type=str, default=None,
257 | help='Location of the tensorboard folder. If not set, tensorboard will not be launched. Default=None i.e. no tensorboard callback')
258 | parser.add_argument('--predict', type=str, choices=['trainingdata', 'validationdata', 'testingdata', 'all', None], default=None,
259 | help='Predict timeseries. Default=None')
260 | parser.add_argument('--series-to-plot', type=str, default='0', help='Series to plot. Default 0 (i.e. plot the first timeseries)')
261 | parser.add_argument('--autocorrelation', type=str, default=None,
262 | help='Plot an autocorrelation of the input data. Format --autocorrelation=i,j,k which means to plot an autocorrelation of timeseries i from timeslot j to timeslot k')
263 | parser.add_argument('--save-plot', type=str, default=None, help='Filename initial to save the plots to in PNG format. Default=None')
264 |
265 | parser.add_argument('--no-train', action='store_true', help='Do not train model.')
266 | parser.add_argument('--no-validation', action='store_true',
267 | help='Do not validate model. When not set and no-train is not set, data will be validated while training')
268 | parser.add_argument('--test', action='store_true', help='Test model.')
269 | parser.add_argument('--no-saveresults', action='store_true', help='Do not save training / validation results.')
270 | parser.add_argument('--savehistory', action='store_true', help='Save training / validation history.')
271 | parser.add_argument('--plot', action='store_true', help='Generate plots.')
272 |
273 | parser.add_argument('--no-log', action='store_true', help='Do not create log files. Only error and critical messages will appear on the console.')
274 | parser.add_argument('--debuglevel', type=int, choices=[10, 20, 30, 40, 50], default=20, help='Logging debug level. Default 20 (INFO)')
275 | parser.add_argument('--logfilename', type=str, default="log/lstnet", help="Filename where logging will be written. Default: log/lstnet")
276 |
277 | args = parser.parse_args()
278 |
279 | return args
280 |
281 |
--------------------------------------------------------------------------------
/lstnet_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from tensorflow.keras.models import Model, model_from_json
6 | from tensorflow.keras.layers import Input, GRU, Conv2D, Dropout, Flatten, Dense, Reshape, Concatenate, Add
7 | from tensorflow.keras.optimizers import SGD, RMSprop, Adam
8 |
9 | from tensorflow.keras import backend as K
10 | from tensorflow.keras.callbacks import TensorBoard
11 |
12 | #######################################################################################################################
13 | # Start Skip RNN specific layers subclass                                                                            #
14 | # #
15 | # The SkipRNN layer is implemented as follows: #
16 | # - A Pre Transformation layer that takes a 'window' of data and applies a couple of reshapes and a transposition   #
17 | #   in order to simulate the skip RNN that is described in the paper                                                #
18 | # - A normal GRU RNN applied to the transformed data. This way it is not adjacent data points that are connected,   #
19 | #   but rather data points that are 'skip' steps apart                                                              #
20 | # - A Post Transformation layer that brings the dimensions back to their original shape, as PreTrans has changed    #
21 | #   all dimensions including the batch size                                                                         #
22 | #######################################################################################################################
23 |
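#
# Illustrative shape walkthrough (ignoring the CNN length reduction; the concrete
# numbers are only an example), with window = 168 and skip = 24, hence pt = 168 / 24 = 7,
# and m input series:
# (batch, 168, m) -- slice last pt*skip --> (batch, 168, m)
#                 -- reshape             --> (batch, 7, 24, m)
#                 -- transpose           --> (batch, 24, 7, m)
#                 -- reshape             --> (batch*24, 7, m)
# The GRU then sees sequences of length pt whose entries are 'skip' steps apart.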
24 | class PreSkipTrans(tf.keras.layers.Layer):
25 | def __init__(self, pt, skip, **kwargs):
26 | #
27 | # pt: Number of different RNN cells = (window / skip)
28 | # skip: Number of points to skip
29 | #
30 | self.pt = pt
31 | self.skip = skip
32 | super(PreSkipTrans, self).__init__(**kwargs)
33 |
34 | def build(self, input_shape):
35 | super(PreSkipTrans, self).build(input_shape)
36 |
37 | def call(self, inputs):
38 | # Get input tensors; in this case it's just one tensor
39 | x = inputs
40 |
41 | # Get the batchsize which is tf.shape(x)[0] since inputs is either X or C which has the same
42 | # batchsize as the input to the model
43 | batchsize = tf.shape(x)[0]
44 |
45 | # Get the shape of the input data
46 | input_shape = K.int_shape(x)
47 |
48 | # Create the output by taking the last pt * skip values of the input
49 | output = x[:,-self.pt * self.skip:,:]
50 |
51 | # Reshape the output tensor by:
52 | # - Changing first dimension (batchsize) from None to the current batchsize
53 | # - Splitting second dimension into 2 dimensions
54 | output = tf.reshape(output, [batchsize, self.pt, self.skip, input_shape[2]])
55 |
56 | # Permutate axis 1 and axis 2
57 | output = tf.transpose(output, perm=[0, 2, 1, 3])
58 |
59 | # Reshape by merging axis 0 and 1 now hence changing the batch size
60 | # to be equal to current batchsize * skip.
61 | # This way the time dimension will only contain 'pt' entries which are
62 | # just values that were originally 'skip' apart from each other => hence skip RNN ready
63 | output = tf.reshape(output, [batchsize * self.skip, self.pt, input_shape[2]])
64 |
65 | # Adjust the output shape by setting back the batch size dimension to None
66 | output_shape = tf.TensorShape([None]).concatenate(output.get_shape()[1:])
67 |
68 | return output
69 |
70 | def compute_output_shape(self, input_shape):
71 | # Since the batch size is None and dimension on axis=2 has not changed,
72 | # all we need to do is set shape[1] = pt in order to compute the output shape
73 | shape = tf.TensorShape(input_shape).as_list()
74 | shape[1] = self.pt
75 |
76 | return tf.TensorShape(shape)
77 |
78 |
79 | def get_config(self):
80 | config = {'pt': self.pt, 'skip': self.skip}
81 | base_config = super(PreSkipTrans, self).get_config()
82 |
83 | return dict(list(base_config.items()) + list(config.items()))
84 |
85 |
86 | @classmethod
87 | def from_config(cls, config):
88 | return cls(**config)
89 |
90 |
91 | class PostSkipTrans(tf.keras.layers.Layer):
92 | def __init__(self, skip, **kwargs):
93 | #
94 | # skip: Number of points to skip
95 | #
96 | self.skip = skip
97 | super(PostSkipTrans, self).__init__(**kwargs)
98 |
99 | def build(self, input_shape):
100 | super(PostSkipTrans, self).build(input_shape)
101 |
102 | def call(self, inputs):
103 | # Get input tensors
104 | # - First one is the output of the SkipRNN layer which we will operate on
105 | # - The second is the original model input tensor which we will use to get
106 | # the original batchsize
107 | x, original_model_input = inputs
108 |
109 | # Get the batchsize which is tf.shape(original_model_input)[0]
110 | batchsize = tf.shape(original_model_input)[0]
111 |
112 | # Get the shape of the input data
113 | input_shape = K.int_shape(x)
114 |
115 | # Split the batch size into the original batch size before PreTrans and 'Skip'
116 | output = tf.reshape(x, [batchsize, self.skip, input_shape[1]])
117 |
118 | # Merge the 'skip' with axis=1
119 | output = tf.reshape(output, [batchsize, self.skip * input_shape[1]])
120 |
121 | # Adjust the output shape by setting back the batch size dimension to None
122 | output_shape = tf.TensorShape([None]).concatenate(output.get_shape()[1:])
123 |
124 | return output
125 |
126 | def compute_output_shape(self, input_shape):
127 | # Adjust shape[1] to be equal to shape[1] * skip in order for the
128 | # shape to reflect the transformation that was done
129 | shape = tf.TensorShape(input_shape).as_list()
130 | shape[1] = self.skip * shape[1]
131 |
132 | return tf.TensorShape(shape)
133 |
134 |
135 | def get_config(self):
136 | config = {'skip': self.skip}
137 | base_config = super(PostSkipTrans, self).get_config()
138 |
139 | return dict(list(base_config.items()) + list(config.items()))
140 |
141 |
142 | @classmethod
143 | def from_config(cls, config):
144 | return cls(**config)
145 |
146 |
147 | #######################################################################################################################
148 | # End Skip RNN specific layers subclass                                                                              #
149 | #######################################################################################################################
150 |
151 |
152 | #######################################################################################################################
153 | # Start AR specific layers subclass                                                                                  #
154 | # #
155 | # The AR layer is implemented as follows: #
156 | # - A Pre Transformation layer that takes a 'highway' size of data and applies a reshape and a transposition        #
157 | # - Flatten the output and pass it through a Dense layer with one output                                            #
158 | # - A Post Transformation layer that brings the dimensions back to their original shape                             #
159 | #######################################################################################################################
160 |
161 | class PreARTrans(tf.keras.layers.Layer):
162 | def __init__(self, hw, **kwargs):
163 | #
164 | # hw: Highway = Number of timeseries values to consider for the linear layer (AR layer)
165 | #
166 | self.hw = hw
167 | super(PreARTrans, self).__init__(**kwargs)
168 |
169 | def build(self, input_shape):
170 | super(PreARTrans, self).build(input_shape)
171 |
172 | def call(self, inputs):
173 | # Get input tensors; in this case it's just one tensor: X = the input to the model
174 | x = inputs
175 |
176 | # Get the batchsize which is tf.shape(x)[0]
177 | batchsize = tf.shape(x)[0]
178 |
179 | # Get the shape of the input data
180 | input_shape = K.int_shape(x)
181 |
182 | # Select only 'highway' length of input to create output
183 | output = x[:,-self.hw:,:]
184 |
185 | # Permute axis 1 and 2. axis=2 is the dimension holding the different time-series
186 | # This dimension should be equal to 'm' which is the number of time-series.
187 | output = tf.transpose(output, perm=[0,2,1])
188 |
189 | # Merge axis 0 and 1 in order to change the batch size
190 | output = tf.reshape(output, [batchsize * input_shape[2], self.hw])
191 |
192 | # Adjust the output shape by setting back the batch size dimension to None
193 | output_shape = tf.TensorShape([None]).concatenate(output.get_shape()[1:])
194 |
195 | return output
196 |
197 | def compute_output_shape(self, input_shape):
198 | # Set the shape of axis=1 to be hw since the batchsize is None
199 | shape = tf.TensorShape(input_shape).as_list()
200 | shape[1] = self.hw
201 |
202 | return tf.TensorShape(shape)
203 |
204 | def get_config(self):
205 | config = {'hw': self.hw}
206 | base_config = super(PreARTrans, self).get_config()
207 |
208 | return dict(list(base_config.items()) + list(config.items()))
209 |
210 | @classmethod
211 | def from_config(cls, config):
212 | return cls(**config)
213 |
214 |
215 | class PostARTrans(tf.keras.layers.Layer):
216 | def __init__(self, m, **kwargs):
217 | #
218 | # m: Number of timeseries
219 | #
220 | self.m = m
221 | super(PostARTrans, self).__init__(**kwargs)
222 |
223 | def build(self, input_shape):
224 | super(PostARTrans, self).build(input_shape)
225 |
226 | def call(self, inputs):
227 | # Get input tensors
228 | # - First one is the output of the Dense(1) layer which we will operate on
229 | # - The second is the original model input tensor which we will use to get
230 | # the original batchsize
231 | x, original_model_input = inputs
232 |
233 | # Get the batchsize which is tf.shape(original_model_input)[0]
234 | batchsize = tf.shape(original_model_input)[0]
235 |
236 | # Get the shape of the input data
237 | input_shape = K.int_shape(x)
238 |
239 | # Reshape the output to have the batch size equal to the original batchsize before PreARTrans
240 | # and the second dimension as the number of timeseries
241 | output = tf.reshape(x, [batchsize, self.m])
242 |
243 | # Adjust the output shape by setting back the batch size dimension to None
244 | output_shape = tf.TensorShape([None]).concatenate(output.get_shape()[1:])
245 |
246 | return output
247 |
248 | def compute_output_shape(self, input_shape):
249 | # Adjust shape[1] to be equal 'm'
250 | shape = tf.TensorShape(input_shape).as_list()
251 | shape[1] = self.m
252 |
253 | return tf.TensorShape(shape)
254 |
255 | def get_config(self):
256 | config = {'m': self.m}
257 | base_config = super(PostARTrans, self).get_config()
258 |
259 | return dict(list(base_config.items()) + list(config.items()))
260 |
261 | @classmethod
262 | def from_config(cls, config):
263 | return cls(**config)
264 |
265 | #######################################################################################################################
266 | # End AR specific layers subclass                                                                                    #
267 | #######################################################################################################################
268 |
269 | #######################################################################################################################
270 | # Model Start #
271 | # #
272 | # The model, as per the paper has the following layers: #
273 | # - CNN #
274 | # - GRU #
275 | # - SkipGRU #
276 | # - AR #
277 | #######################################################################################################################
278 |
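#
# Tensor shapes through the model (batch dimension omitted), for window w and m series,
# when the CNN is enabled:
# X: (w, m) -> CNN: (w - CNNKernel + 1, CNNFilters) -> GRU state: (GRUUnits,)
# SkipGRU branch: (skip' * SkipGRUUnits,) with skip' = int((w - CNNKernel + 1) / pt),
# concatenated with the GRU state; Flatten + Dense then map this to (m,).
# The AR branch also produces (m,) and the two outputs are added.
#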
279 | def LSTNetModel(init, input_shape):
280 |
281 | # m is the number of time-series
282 | m = input_shape[2]
283 |
284 | # Get tensor shape except batchsize
285 | tensor_shape = input_shape[1:]
286 |
287 | if K.image_data_format() == 'channels_last':
288 | ch_axis = 3
289 | else:
290 | ch_axis = 1
291 |
292 | X = Input(shape = tensor_shape)
293 |
294 | # CNN
295 | if init.CNNFilters > 0 and init.CNNKernel > 0:
296 | # Add an extra dimension of size 1 which is the channel dimension in Conv2D
297 | C = Reshape((input_shape[1], input_shape[2], 1))(X)
298 |
299 | # Apply a Conv2D that will transform it into data of dimensions (batchsize, time, 1, NumofFilters)
300 | C = Conv2D(filters=init.CNNFilters, kernel_size=(init.CNNKernel, m), kernel_initializer=init.initialiser)(C)
301 | C = Dropout(init.dropout)(C)
302 |
303 | # Adjust data dimensions by removing axis=2 which is always equal to 1
304 | c_shape = K.int_shape(C)
305 | C = Reshape((c_shape[1], c_shape[3]))(C)
306 | else:
307 | # If configured not to apply CNN, copy the input
308 | C = X
309 |
310 | # GRU
311 | # Apply a GRU layer (with activation set to 'relu' as per the paper) and take the returned states as result
312 | _, R = GRU(init.GRUUnits, activation="relu", return_sequences = False, return_state = True)(C)
313 | R = Dropout(init.dropout)(R)
314 |
315 | # SkipGRU
316 | if init.skip > 0:
317 | # Calculate the number of values to use which is equal to the window divided by how many time values to skip
318 | pt = int(init.window / init.skip)
319 |
320 | S = PreSkipTrans(pt, int((init.window - init.CNNKernel + 1) / pt))(C)
321 | _, S = GRU(init.SkipGRUUnits, activation="relu", return_sequences = False, return_state = True)(S)
322 | S = PostSkipTrans(int((init.window - init.CNNKernel + 1) / pt))([S,X])
323 |
324 | # Concatenate the outputs of GRU and SkipGRU
325 | R = Concatenate(axis=1)([R,S])
326 |
327 | # Dense layer
328 | Y = Flatten()(R)
329 | Y = Dense(m)(Y)
330 |
331 | # AR
332 | if init.highway > 0:
333 | Z = PreARTrans(init.highway)(X)
334 | Z = Flatten()(Z)
335 | Z = Dense(1)(Z)
336 | Z = PostARTrans(m)([Z,X])
337 |
338 | # Generate output as the summation of the Dense layer output and the AR one
339 | Y = Add()([Y,Z])
340 |
341 | # Generate Model
342 | model = Model(inputs = X, outputs = Y)
343 |
344 | return model
345 |
346 | #######################################################################################################################
347 | # Model End #
348 | #######################################################################################################################
349 |
350 | #######################################################################################################################
351 | # Model Utilities #
352 | # #
353 | # Below is a collection of functions: #
354 | # - rse: A metrics function that calculates the root square error #
355 | # - corr: A metrics function that calculates the correlation #
356 | #######################################################################################################################
357 | def rse(y_true, y_pred):
358 | #
359 | # The formula is:
360 | # K.sqrt(K.sum(K.square(y_true - y_pred)))
361 | # RSE = -----------------------------------------------
362 | # K.sqrt(K.sum(K.square(y_true_mean - y_true)))
363 | #
364 | # K.sqrt(K.sum(K.square(y_true - y_pred))/(N-1))
365 | # = ----------------------------------------------------
366 | # K.sqrt(K.sum(K.square(y_true_mean - y_true)/(N-1)))
367 | #
368 | #
369 | # K.sqrt(K.mean(K.square(y_true - y_pred)))
370 | # = ------------------------------------------
371 | # K.std(y_true)
372 | #
373 | num = K.sqrt(K.mean(K.square(y_true - y_pred), axis=None))
374 | den = K.std(y_true, axis=None)
375 |
376 | return num / den
377 |
378 |
379 | def corr(y_true, y_pred):
380 | #
381 | # This function calculates the correlation between the true and the predicted outputs
382 | #
383 | num1 = y_true - K.mean(y_true, axis=0)
384 | num2 = y_pred - K.mean(y_pred, axis=0)
385 |
386 | num = K.mean(num1 * num2, axis=0)
387 | den = K.std(y_true, axis=0) * K.std(y_pred, axis=0)
388 |
389 | return K.mean(num / den)
390 |
391 | #
392 | # A function that compiles 'model' after setting the appropriate:
393 | # - optimiser function passed via init
394 | # - learning rate passed via init
395 | # - loss function also set in init
396 | # - metrics
397 | #
398 | def ModelCompile(model, init):
399 | # Select the appropriate optimiser and set the learning rate from input values (arguments)
400 | if init.optimiser == "SGD":
401 | opt = SGD(lr=init.lr, momentum=0.0, decay=0.0, nesterov=False)
402 | elif init.optimiser == "RMSprop":
403 | opt = RMSprop(lr=init.lr, rho=0.9, epsilon=None, decay=0.0)
404 | else: # Adam
405 | opt = Adam(lr=init.lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
406 |
407 | # Compile using the previously defined metrics
408 | model.compile(optimizer = opt, loss = init.loss, metrics=[rse, corr])
409 |
410 | # Create a TensorBoard callback if requested in the arguments
411 | if init.tensorboard is not None:
412 | tensorboard = TensorBoard(log_dir=init.tensorboard)
413 | else:
414 | tensorboard = None
415 |
416 | return tensorboard
417 |
--------------------------------------------------------------------------------