├── .dockerignore
├── .gitignore
├── .travis.yml
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── da_rnn
│   ├── __init__.py
│   ├── common.py
│   ├── keras
│   │   ├── __init__.py
│   │   └── model.py
│   └── torch
│       ├── __init__.py
│       └── model.py
├── dev-requirements.txt
├── notebook
│   ├── __init__.py
│   ├── common.py
│   ├── keras.ipynb
│   ├── nasdaq100_padding.csv
│   └── pytorch.ipynb
├── requirements.txt
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    └── test_da_rnn.py

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
*
!da_rnn
!requirements.txt
*.pyc
*.log.*
*.log
__pycache__
*.csv

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Python
__pycache__
*.pyc
*.egg-info
/dist
/.tox
/.pytest_cache
/data
/notebook/*
!/notebook/*.py
!/notebook/keras.ipynb
!/notebook/pytorch.ipynb

# Jupyter notebook checkpoints
.ipynb_checkpoints

# Tensorflow checkpoints
*.hdf5

# Coverage
/.nyc_output
/coverage
/.coverage
/coverage.lcov

# Tarballs
*.tgz

# Numerous always-ignore extensions
*.bak
*.patch
*.diff
*.err
*.orig
*.log
*.rej
*.swo
*.swp
*.zip
*.vi
*~
*.sass-cache
*.lock
*.rdb
*.db
nohup.out

# OS or Editor folders
.DS_Store
._*
.cache
.project
.settings
.tmproj
*.esproj
*.*-project
*.*-workspace
nbproject
thumbs.db

# Folders to ignore
.hg
.svn
.CVS
.idea
node_modules
old/
*_old/
*_notrack/
no_track/
*_no_track.*
*.no_track.*
no_track.*
build/
combo/
reference/
jscoverage_lib/
temp/
tmp/

# Java
.mvn
.gradle
.vscode
.project
.classpath
/gradle

--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python

python:
  # - "3.7"
  - "3.8"

install:
  - make install

script:
  - make test

after_success:
  - make report

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3

WORKDIR /usr/src/app

COPY requirements.txt ./

RUN pip install -r requirements.txt
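
# Copying the source only after installing dependencies keeps the pip layer cacheable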
10 | 11 | CMD ["python", "start.py"] 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 kaelzhang <>, contributors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | files = da_rnn test *.py 2 | test_files = * 3 | 4 | test: 5 | pytest -s -v test/test_$(test_files).py --doctest-modules --cov da_rnn --cov-config=.coveragerc --cov-report term-missing 6 | 7 | lint: 8 | flake8 $(files) 9 | 10 | fix: 11 | autopep8 --in-place -r $(files) 12 | 13 | install: 14 | pip install -U -r requirements.txt -r dev-requirements.txt 15 | 16 | report: 17 | codecov 18 | 19 | build: da_rnn 20 | rm -rf dist 21 | python setup.py sdist bdist_wheel 22 | 23 | publish: 24 | make build 25 | twine upload --config-file ~/.pypirc -r pypi dist/* 26 | 27 | .PHONY: test build 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://travis-ci.org/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch.svg?branch=master)](https://travis-ci.org/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch) 2 | [![](https://codecov.io/gh/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch/branch/master/graph/badge.svg)](https://codecov.io/gh/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch) 3 | [![](https://img.shields.io/pypi/v/da-rnn.svg)](https://pypi.org/project/da_rnn/) 4 | [![](https://img.shields.io/pypi/l/da-rnn.svg)](https://github.com/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch) 5 | 6 | # Tensorflow 2 / Torch DA-RNN 7 | 8 | A Tensorflow 2 (Keras) and pytorch implementation of the [Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction](https://arxiv.org/abs/1704.02971) 9 | 10 | Paper: [https://arxiv.org/abs/1704.02971](https://arxiv.org/abs/1704.02971) 11 | 12 | ## Run notebook demo 13 | 14 | Install dependencies (It is recommended to use [anaconda](https://docs.anaconda.com/anaconda/install/) to manage environments): 15 | 16 | ```sh 17 | make install 18 | ``` 19 | 20 | Run notebook: 21 | 22 | ```sh 23 | cd notebook 24 | jupyter lab 25 | 26 | # Run `pytorch.ipynb` 27 | ``` 28 | 29 | 30 | ## Install 31 | 32 | For Tensorflow 2 33 | 34 | 
```sh
pip install da-rnn[keras]
```

For PyTorch

```sh
pip install da-rnn[torch]
```

## Usage

For TensorFlow 2 (Still buggy for now)

```py
from da_rnn.keras import DARNN

model = DARNN(T=10, m=128)

# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
y_hat = model(inputs)
```
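
In the snippets above and below, `train_ds`, `val_ds` and `inputs` are placeholders for your own data. As a minimal sketch for the Keras model (assuming `features`/`labels` and `val_features`/`val_labels` are NumPy arrays shaped as described in the Data Processing section below), the datasets could be built with `tf.data`:

```py
import tensorflow as tf

# features: (num_windows, T, n + y_dim), labels: (num_windows, y_dim)
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).batch(32)
```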

For PyTorch (Tested. Works)

```py
import torch
from poutyne import Model
from da_rnn.torch import DARNN

darnn = DARNN(n=50, T=10, m=128)
model = Model(darnn)

# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
with torch.no_grad():
    y_hat = model(inputs)
```

### Python Docstring Notations

In docstrings of the methods of this project, we use the following notation convention:

```
variable_{subscript}__{superscript}
```

For example:

- `y_T__i` means ![y_T__i](https://render.githubusercontent.com/render/math?math=y_T^i), the `i`-th prediction value at time `T`.
- `alpha_t__k` means ![alpha_t__k](https://render.githubusercontent.com/render/math?math=\alpha_t^k), the attention weight measuring the importance of the `k`-th input feature (driving series) at time `t`.

### DARNN(T, m, p, y_dim=1) (Keras)
### DARNN(n, T, m, p, y_dim=1) (PyTorch)

> The naming of the following (hyper)parameters is consistent with the paper, except `y_dim`, which is not mentioned in the paper.

- **n** (torch only) `int` input size, the number of features of a single driving series
- **T** `int` the length (time steps) of the window
- **m** `int` the number of the encoder hidden states
- **p** `int` the number of the decoder hidden states
- **y_dim** `int=1` the prediction dimension. Defaults to `1`.

Returns the DA-RNN model instance.

## Data Processing

Each feature item of the dataset should be of shape `(batch_size, T, length_of_driving_series + y_dim)`

And each label item of the dataset should be of shape `(batch_size, y_dim)`
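
To make these shapes concrete, here is a minimal sketch (PyTorch flavor, with random data and arbitrary example sizes) that pushes one batch through the model:

```py
import torch
from da_rnn.torch import DARNN, DEVICE

batch_size, T, n, y_dim = 16, 10, 50, 1

# Feature item: the n driving series concatenated with the y column(s)
inputs = torch.randn(batch_size, T, n + y_dim, device=DEVICE)

# The model never reads y_T, so its slot is conventionally zeroed
inputs[:, -1, -y_dim:] = 0

model = DARNN(n=n, T=T, m=128).to(DEVICE)

with torch.no_grad():
    y_hat = model(inputs)

print(y_hat.shape)  # torch.Size([16, 1]), i.e. (batch_size, y_dim)
```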

## TODO
- [x] no hardcoding (`1` for now) for prediction dimensionality

## License

[MIT](LICENSE)

--------------------------------------------------------------------------------
/da_rnn/__init__.py:
--------------------------------------------------------------------------------
__version__ = '1.0.2'

--------------------------------------------------------------------------------
/da_rnn/common.py:
--------------------------------------------------------------------------------
def check_T(T: int):
    if T < 2:
        raise ValueError(
            f'T must be an integer larger than 1, but got `{T}`'
        )

--------------------------------------------------------------------------------
/da_rnn/keras/__init__.py:
--------------------------------------------------------------------------------
from .model import (
    DARNN,
    Encoder,
    Decoder
)

--------------------------------------------------------------------------------
/da_rnn/keras/model.py:
--------------------------------------------------------------------------------
from typing import Optional

import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.layers import (
    Layer,
    LSTM,
    Dense,
    Permute
)

from tensorflow.keras.models import Model

from da_rnn.common import (
    check_T
)


"""
Notation (according to the paper)

Naming Convention::

    Variable_{time_step}__{sequence_number_of_driving_series}

Variables / HyperParameters:
    T (int): the size (time steps) of the window
    m (int): the number of the encoder hidden states
    p (int): the number of the decoder hidden states
    n (int): the number of features of a single driving series
    X: the n driving (exogenous) series of shape (batch_size, T, n)
    X_tilde: the new input for the encoder, i.e. X̃ = (x̃_1, ..., x̃_t, ..., x̃_T)
    Y: the historical/previous T - 1 predictions, (y_1, y_2, ..., y_Tminus1)

    hidden_state / h: hidden state
    cell_state / s: cell state
    Alpha_t: attention weights of the input attention layer at time t
    Beta_t: attention weights of the temporal attention layer at time t
"""


class InputAttention(Layer):
    T: int

    def __init__(self, T, **kwargs):
        """
        Calculates the encoder attention weight Alpha_t at time t

        Args:
            T (int): the size (time steps) of the window
        """

        super().__init__(name='input_attention', **kwargs)

        self.T = T

        self.W_e = Dense(T)
        self.U_e = Dense(T)
        self.v_e = Dense(1)

    def call(
        self,
        hidden_state,
        cell_state,
        X
    ):
        """
        Args:
            hidden_state: hidden state of shape (batch_size, m) at time t - 1
            cell_state: cell state of shape (batch_size, m) at time t - 1
            X: the n driving (exogenous) series of shape (batch_size, T, n)

        Returns:
            The attention weights (Alpha_t) at time t, i.e.
            (a_t__1, a_t__2, ..., a_t__n)
        """

        n = X.shape[2]

        # [h_t-1; s_t-1]
        hs = K.repeat(
            tf.concat([hidden_state, cell_state], axis=-1),
            # -> (batch_size, m * 2)
            n
        )
        # -> (batch_size, n, m * 2)

        tanh = tf.math.tanh(
            tf.concat([
                self.W_e(hs),
                # -> (batch_size, n, T)

                self.U_e(
                    Permute((2, 1))(X)
                    # -> (batch_size, n, T)
                ),
                # -> (batch_size, n, T)
            ], axis=-1)
            # -> (batch_size, n, T * 2)
        )
        # -> (batch_size, n, T * 2)

        # Equation 8
        e = self.v_e(tanh)
        # -> (batch_size, n, 1)

        # Equation 9
        return tf.nn.softmax(
            Permute((2, 1))(e)
            # -> (batch_size, 1, n)
        )
        # -> (batch_size, 1, n)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'T': self.T
        })
        return config


class Encoder(Layer):
    T: int
    m: int

    def __init__(
        self,
        T: int,
        m: int,
        **kwargs
    ):
        """
        Generates the new input X_tilde for the encoder

        Args:
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
        """

        super().__init__(name='encoder_input', **kwargs)

        self.T = T
        self.m = m

        self.input_lstm = LSTM(m, return_state=True)
        self.input_attention = InputAttention(T)

    def call(self, X) -> tf.Tensor:
        """
        Args:
            X: the n driving (exogenous) series of shape (batch_size, T, n)

        Returns:
            The encoder hidden state of shape (batch_size, T, m)
        """

        batch_size = K.shape(X)[0]

        hidden_state = tf.zeros((batch_size, self.m))
        cell_state = tf.zeros((batch_size, self.m))

        X_encoded = []

        for t in range(self.T):
            Alpha_t = self.input_attention(hidden_state, cell_state, X)

            # Equation 10
            X_tilde_t = tf.multiply(
                Alpha_t,
                # TODO:
                # make sure it can share the underlying data
                X[:, None, t, :]
            )
            # -> (batch_size, 1, n)

            # Equation 11
            hidden_state, _, cell_state = self.input_lstm(
                X_tilde_t,
                initial_state=[hidden_state, cell_state]
            )

            X_encoded.append(
                hidden_state[:, None, :]
                # -> (batch_size, 1, m)
            )

        return tf.concat(X_encoded, axis=1)
        # -> (batch_size, T, m)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'T': self.T,
            'm': self.m
        })
        return config


class TemporalAttention(Layer):
    m: int

    def __init__(self, m: int, **kwargs):
        """
        Calculates the attention weights::

            Beta_t = (beta_t__1, ..., beta_t__i, ..., beta_t__T) (1 <= i <= T)

        for each encoder hidden state h_t at the time step t

        Args:
            m (int): the number of the encoder hidden states
        """

        super().__init__(name='temporal_attention', **kwargs)

        self.m = m

        self.W_d = Dense(m)
        self.U_d = Dense(m)
        self.v_d = Dense(1)

    def call(
        self,
        hidden_state,
        cell_state,
        X_encoded
    ):
        """
        Args:
            hidden_state: hidden state `d` of shape (batch_size, p)
            cell_state: cell state `s` of shape (batch_size, p)
            X_encoded: the encoder hidden states (batch_size, T, m)

        Returns:
            The attention weights for encoder hidden states (beta_t)
        """

        # Equation 12
        l = self.v_d(
            tf.math.tanh(
                tf.concat([
                    self.W_d(
                        K.repeat(
                            tf.concat([hidden_state, cell_state], axis=-1),
                            # -> (batch_size, p * 2)
                            X_encoded.shape[1]
                        )
                        # -> (batch_size, T, p * 2)
                    ),
                    # -> (batch_size, T, m)
                    self.U_d(X_encoded)
                ], axis=-1)
                # -> (batch_size, T, m * 2)
            )
            # -> (batch_size, T, m)
        )
        # -> (batch_size, T, 1)

        # Equation 13
        return tf.nn.softmax(l, axis=1)
        # -> (batch_size, T, 1)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'm': self.m
        })
        return config


class Decoder(Layer):
    T: int
    m: int
    p: int
    y_dim: int

    def __init__(
        self,
        T: int,
        m: int,
        p: int,
        y_dim: int,
        **kwargs
    ):
        """
        Calculates y_hat_T

        Args:
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
            p (int): the number of the decoder hidden states
            y_dim (int): prediction dimensionality
        """

        super().__init__(name='decoder', **kwargs)

        self.T = T
        self.m = m
        self.p = p
        self.y_dim = y_dim

        self.temp_attention = TemporalAttention(m)
        self.dense = Dense(1)
        self.decoder_lstm = LSTM(p, return_state=True)

        self.Wb = Dense(p)
        self.vb = Dense(y_dim)

    def call(self, Y, X_encoded) -> tf.Tensor:
        """
        Args:
            Y: prediction data of shape (batch_size, T - 1, y_dim) from time 1 to time T - 1. See Figure 1(b) in the paper
            X_encoded: encoder hidden states of shape (batch_size, T, m)

        Returns:
            y_hat_T: the prediction of shape (batch_size, y_dim)
        """

        batch_size = K.shape(X_encoded)[0]
        hidden_state = tf.zeros((batch_size, self.p))
        cell_state = tf.zeros((batch_size, self.p))

        # c in the paper
        context_vector = tf.zeros((batch_size, 1, self.m))
        # -> (batch_size, 1, m)

        for t in range(self.T - 1):
            Beta_t = self.temp_attention(
                hidden_state,
                cell_state,
                X_encoded
            )
            # -> (batch_size, T, 1)

            # Equation 14
            context_vector = tf.matmul(
                Beta_t, X_encoded, transpose_a=True
            )
            # -> (batch_size, 1, m)

            # Equation 15
            y_tilde = self.dense(
                tf.concat([Y[:, None, t, :], context_vector], axis=-1)
                # -> (batch_size, 1, y_dim + m)
            )
            # -> (batch_size, 1, 1)

            # Equation 16
            hidden_state, _, cell_state = self.decoder_lstm(
                y_tilde,
                initial_state=[hidden_state, cell_state]
            )
            # -> (batch_size, p)

        concatenated = tf.concat(
            [hidden_state[:, None, :], context_vector], axis=-1
        )
        # -> (batch_size, 1, m + p)

        # Equation 22
        y_hat_T = self.vb(
            self.Wb(concatenated)
            # -> (batch_size, 1, p)
        )
        # -> (batch_size, 1, y_dim)

        return tf.squeeze(y_hat_T, axis=1)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'T': self.T,
            'm': self.m,
            'p': self.p,
            'y_dim': self.y_dim
        })
        return config


class DARNN(Model):
    def __init__(
        self,
        T: int,
        m: int,
        p: Optional[int] = None,
        y_dim: int = 1
    ):
        """
        Args:
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
            p (:obj:`int`, optional): the number of the decoder hidden states. Defaults to `m`
            y_dim (:obj:`int`, optional): prediction dimensionality. Defaults to `1`

        Model Args:
            inputs: the concatenation of
            - n driving series (x_1, x_2, ..., x_T) and
            - the previous (historical) T - 1 predictions (y_1, y_2, ..., y_Tminus1, zero)

        `inputs` Explanation::

            inputs_t = (x_t__1, x_t__2, ..., x_t__n, y_t__1, y_t__2, ..., y_t__d)

            where
            - d is the prediction dimension
            - y_T__i = 0, 1 <= i <= d.

            Actually, the model will not use the value of y_T

        Usage::

            model = DARNN(10, 64, 64)
            y_hat = model(inputs)
        """

        super().__init__(name='DARNN')

        check_T(T)

        self.T = T
        self.m = m
        self.p = p or m
        self.y_dim = y_dim

        self.encoder = Encoder(T, m)
        self.decoder = Decoder(T, m, self.p, y_dim=y_dim)

    # Equation 1
    def call(self, inputs):
        X = inputs[:, :, :-self.y_dim]
        # -> (batch_size, T, n)

        # Y's effective window size is one less than X's,
        # so `y_T` is abandoned

        # By doing this, there are some benefits which make it pretty easy to
        # process datasets
        Y = inputs[:, :, -self.y_dim:]
        # -> (batch_size, T, y_dim); only the first T - 1 steps are used

        X_encoded = self.encoder(X)

        y_hat_T = self.decoder(Y, X_encoded)
        # -> (batch_size, y_dim)

        return y_hat_T

    def get_config(self):
        return {
            'T': self.T,
            'm': self.m,
            'p': self.p,
            'y_dim': self.y_dim
        }

--------------------------------------------------------------------------------
/da_rnn/torch/__init__.py:
--------------------------------------------------------------------------------
from .model import (
    DARNN,
    Encoder,
    Decoder,
    DEVICE
)

--------------------------------------------------------------------------------
/da_rnn/torch/model.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from torch.nn import (
    Module,
    Linear,
    LSTM
)

from da_rnn.common import (
    check_T
)


DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class Encoder(Module):
    n: int
    T: int
    m: int

    DEVICE = DEVICE

    def __init__(
        self,
        n: int,
        T: int,
        m: int,
        dropout
    ):
        """
        Generates the new input X_tilde for the encoder

        Args:
            n (int): input size, the number of features of a single driving series
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
        """

        super().__init__()

        self.n = n
        self.T = T
        self.m = m

        # Two linear layers form a bigger linear layer
        self.WU_e = Linear(m * 2 + T, T, False)

        # Since v_e ∈ R^T, the input size is T
        self.v_e = Linear(T, 1, False)

        self.lstm = LSTM(self.n, self.m, dropout=dropout)

    def forward(self, X):
        """
        Args:
            X: the n driving (exogenous) series of shape (batch_size, T, n)

        Returns:
            The encoder hidden state of shape (T, batch_size, m)
        """

        batch_size = X.shape[0]
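
        # Initial LSTM states, shaped (1, batch_size, m),
        # created on the encoder's device (see DEVICE above)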
        hidden_state = torch.zeros(1, batch_size, self.m, device=self.DEVICE)
        cell_state = torch.zeros(1, batch_size, self.m, device=self.DEVICE)

        X_encoded = torch.zeros(self.T, batch_size, self.m, device=self.DEVICE)

        for t in range(self.T):
            # [h_t-1; s_t-1]
            hs = torch.cat((hidden_state, cell_state), 2)
            # -> (1, batch_size, m * 2)

            hs = hs.permute(1, 0, 2).repeat(1, self.n, 1)
            # -> (batch_size, n, m * 2)

            tanh = torch.tanh(
                self.WU_e(
                    torch.cat((hs, X.permute(0, 2, 1)), 2)
                    # -> (batch_size, n, m * 2 + T)
                )
            )
            # -> (batch_size, n, T)

            # Equation 8
            E = self.v_e(tanh).view(batch_size, self.n)
            # -> (batch_size, n)

            # Equation 9
            Alpha_t = torch.softmax(E, 1)
            # -> (batch_size, n)

            # Ref
            # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
            # The input shape of torch LSTM should be
            # (seq_len, batch, n)
            _, (hidden_state, cell_state) = self.lstm(
                (X[:, t, :] * Alpha_t).unsqueeze(0),
                # -> (1, batch_size, n)
                (hidden_state, cell_state)
            )

            X_encoded[t] = hidden_state[0]

        return X_encoded


class Decoder(Module):
    T: int
    p: int

    DEVICE = DEVICE

    def __init__(
        self,
        T: int,
        m: int,
        p: int,
        y_dim: int,
        dropout
    ):
        """
        Calculates y_hat_T

        Args:
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
            p (int): the number of the decoder hidden states
            y_dim (int): prediction dimensionality
        """

        super().__init__()

        self.T = T
        self.p = p

        self.WU_d = Linear(p * 2 + m, m, False)
        self.v_d = Linear(m, 1, False)
        self.wb_tilde = Linear(y_dim + m, 1, False)

        self.lstm = LSTM(1, p, dropout=dropout)

        self.Wb = Linear(p + m, p)
        self.vb = Linear(p, y_dim)

    def forward(self, Y, X_encoded):
        """
        Args:
            Y: prediction data of shape (batch_size, T - 1, y_dim) from time 1 to time T - 1. See Figure 1(b) in the paper
            X_encoded: encoder hidden states of shape (T, batch_size, m)

        Returns:
            y_hat_T: the prediction of shape (batch_size, y_dim)
        """

        batch_size = Y.shape[0]

        hidden_state = torch.zeros(1, batch_size, self.p, device=self.DEVICE)
        cell_state = torch.zeros(1, batch_size, self.p, device=self.DEVICE)

        for t in range(self.T - 1):
            # Equation 12
            l = self.v_d(
                torch.tanh(
                    self.WU_d(
                        torch.cat(
                            (
                                torch.cat(
                                    (hidden_state, cell_state),
                                    2
                                ).permute(1, 0, 2).repeat(1, self.T, 1),
                                # -> (batch_size, T, p * 2)

                                X_encoded.permute(1, 0, 2)
                                # -> (batch_size, T, m)
                            ),
                            2
                        )
                    )
                    # -> (batch_size, T, m * 2)
                )
                # -> (batch_size, T, m)
            ).view(batch_size, self.T)
            # -> (batch_size, T)

            # Equation 13
            Beta_t = torch.softmax(l, 1)
            # -> (batch_size, T)

            # Equation 14
            context_vector = torch.bmm(
                Beta_t.unsqueeze(1),
                # -> (batch_size, 1, T)
                X_encoded.permute(1, 0, 2)
                # -> (batch_size, T, m)
            ).squeeze(1)
            # -> (batch_size, m)

            # Equation 15
            y_tilde = self.wb_tilde(
                torch.cat((Y[:, t, :], context_vector), 1)
                # -> (batch_size, y_dim + m)
            )
            # -> (batch_size, 1)

            # Equation 16
            _, (hidden_state, cell_state) = self.lstm(
                y_tilde.unsqueeze(0),
                # -> (1, batch_size, 1)
                (hidden_state, cell_state)
            )

        # Equation 22
        y_hat_T = self.vb(
            self.Wb(
                torch.cat((hidden_state.squeeze(0), context_vector), 1)
                # -> (batch_size, p + m)
            )
            # -> (batch_size, p)
        )
        # -> (batch_size, y_dim)

        return y_hat_T


class DARNN(Module):
    y_dim: int

    DEVICE = DEVICE

    def __init__(
        self,
        n: int,
        T: int,
        m: int,
        p: Optional[int] = None,
        y_dim: int = 1,
        dropout=0
    ):
        """
        Args:
            n (int): input size, the number of features of a single driving series
            T (int): the size (time steps) of the window
            m (int): the number of the encoder hidden states
            p (:obj:`int`, optional): the number of the decoder hidden states. Defaults to `m`
            y_dim (:obj:`int`, optional): prediction dimensionality. Defaults to `1`

        Model Args:
            inputs: the concatenation of
            - n driving series (x_1, x_2, ..., x_T) and
            - the previous (historical) T - 1 predictions (y_1, y_2, ..., y_Tminus1, zero)

        `inputs` Explanation::

            inputs_t = (x_t__1, x_t__2, ..., x_t__n, y_t__1, y_t__2, ..., y_t__d)

            where
            - d is the prediction dimension
            - y_T__i = 0, 1 <= i <= d.

            Actually, the model will not use the value of y_T

        Usage::

            model = DARNN(n=50, T=10, m=128)
            y_hat = model(inputs)
        """

        super().__init__()

        check_T(T)

        self.y_dim = y_dim

        self.encoder = Encoder(n, T, m, dropout)
        self.decoder = Decoder(T, m, p or m, y_dim, dropout)

    def forward(self, inputs):
        X, Y = torch.split(
            inputs,
            [inputs.shape[2] - self.y_dim, self.y_dim],
            dim=2
        )
        # X -> (batch_size, T, n), Y -> (batch_size, T, y_dim)

        return self.decoder(Y, self.encoder(X))

--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
codecov
coverage
flake8
autopep8
pytest
pytest-cov
setuptools
twine
jupyterlab
get_rolling_window
torch
# tensorflow
Poutyne
pandas
scikit-learn
matplotlib

--------------------------------------------------------------------------------
/notebook/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch/1a30c3dffd7e8556c775480faa500d84fed5334c/notebook/__init__.py

--------------------------------------------------------------------------------
/notebook/common.py:
--------------------------------------------------------------------------------
from get_rolling_window import rolling_window


def get_labels_from_features(features, window_size, y_dim):
    return features[window_size - 1:, -y_dim:]


def split_by_ratio(features, validation_ratio):
    length = len(features)
    validation_length = int(validation_ratio * length)

    return features[:-validation_length], features[-validation_length:]


def split_data(
    data,
    apply,
    window_size,
    y_dim,
    validation_ratio
):
    train_data, val_data = split_by_ratio(data, validation_ratio)

    train_f, train_l = rolling_window(
        train_data, window_size, 1
    ), get_labels_from_features(train_data, window_size, y_dim)

    val_f, val_l = rolling_window(
        val_data, window_size, 1
    ), get_labels_from_features(val_data, window_size, y_dim)

    return apply(train_f), apply(train_l), apply(val_f), apply(val_l)
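

# Usage sketch (hypothetical sizes, for illustration only):
#
#   import numpy as np
#   import torch
#
#   data = np.random.rand(1000, 51)  # 50 driving series + 1 target column
#   train_f, train_l, val_f, val_l = split_data(
#       data, torch.Tensor, window_size=10, y_dim=1, validation_ratio=0.2
#   )
#   # train_f: (791, 10, 51), train_l: (791, 1)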
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Newer versions of numpy are not compatible with tensorflow 2.4.1
numpy  # <= 1.19.5

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
description-file = README.md

[flake8]
per-file-ignores =
    **/__init__.py:F401
# exclude =
#     no_track
ignore =
    E501
    E741
    W503

[tool:pytest]
log_cli=true

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from pathlib import Path
from setuptools import setup

from da_rnn import __version__


# Utility function to read the README file.
# Used for the long_description.  It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ...
def read(fname):
    return open(Path(__file__).parent / fname).read()


def read_requirements(filename):
    with open(filename) as f:
        return f.read().splitlines()


settings = dict(
    name='da-rnn',
    packages=[
        'da_rnn',
        'da_rnn/keras',
        'da_rnn/torch'
    ],
    version=__version__,
    author='kaelzhang',
    author_email='',
    description=('A TensorFlow 2 (Keras) and PyTorch implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction'),
    license='MIT',
    keywords='da_rnn',
    url='https://github.com/kaelzhang/tensorflow-2.0-DA-RNN',
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
    python_requires='>=3.7',
    install_requires=read_requirements('requirements.txt'),
    extras_require={
        'keras': ['tensorflow >= 2'],
        'torch': ['torch']
    },
    tests_require=read_requirements('dev-requirements.txt'),
    classifiers=[
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: Implementation :: PyPy',
        'License :: OSI Approved :: MIT License',
    ]
)


if __name__ == '__main__':
    setup(**settings)

--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaelzhang/DA-RNN-in-Tensorflow-2-and-PyTorch/1a30c3dffd7e8556c775480faa500d84fed5334c/test/__init__.py

--------------------------------------------------------------------------------
/test/test_da_rnn.py:
--------------------------------------------------------------------------------
import da_rnn


def test_main():
    # Placeholder test: merely asserts that the package imports
    pass

--------------------------------------------------------------------------------