├── data
│   ├── __init__.py
│   ├── data_download.py
│   └── custom_dataset.py
├── utils
│   ├── __init__.py
│   └── utils.py
├── data_formatters
│   ├── __init__.py
│   ├── traffic.py
│   ├── volatility.py
│   ├── base.py
│   └── electricity.py
├── README.md
├── .gitignore
├── quantile_loss.py
└── models.py

/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data_formatters/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Generic.
2 | def get_single_col_by_input_type(input_type, column_definition):
3 |   """Returns the name of the single column with the specified input type.
4 |   Args:
5 |     input_type: Input type of column to extract
6 |     column_definition: Column definition list for experiment
7 |   """
8 | 
9 |   l = [tup[0] for tup in column_definition if tup[2] == input_type]
10 | 
11 |   if len(l) != 1:
12 |     raise ValueError('Invalid number of columns for {}'.format(input_type))
13 | 
14 |   return l[0]
15 | 
16 | def extract_cols_from_data_type(data_type, column_definition,
17 |                                 excluded_input_types):
18 |   """Extracts the names of columns that correspond to a given data_type.
19 |   Args:
20 |     data_type: DataType of columns to extract.
21 |     column_definition: Column definition to use.
22 |     excluded_input_types: Set of input types to exclude
23 |   Returns:
24 |     List of names for columns with the data type specified.
25 |   """
26 |   return [
27 |       tup[0]
28 |       for tup in column_definition
29 |       if tup[1] == data_type and tup[2] not in excluded_input_types
30 |   ]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting, implemented in PyTorch
2 | Authors: Bryan Lim, Sercan Arik, Nicolas Loeff and Tomas Pfister
3 | 
4 | Paper Link: [https://arxiv.org/pdf/1912.09363.pdf](https://arxiv.org/pdf/1912.09363.pdf)
5 | 
6 | # Implementation
7 | This repository contains the source code for the Temporal Fusion Transformer reproduced in PyTorch using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), which is used to scale models and write less boilerplate. At the moment, the model is trained on the Electricity dataset from the paper. I am currently working on the code to support the other three datasets described in the paper and to reproduce the published results.
8 | 
9 | - **data_formatters**: Stores the main dataset-specific column definitions, along with functions for data transformation and normalization. For compatibility with the TFT, new experiments should implement a unique GenericDataFormatter (see base.py), with examples for the default experiments shown in the other Python files.
10 | 
11 | - **data**: Stores the main dataset-specific download procedure, along with the PyTorch dataset class ready to use as input to the dataloader and the model.
12 | 
13 | # Training
14 | To run the training procedure, open up **training_tft.ipynb** and execute all cells to start training. A minimal sketch of how the individual pieces fit together is shown below.
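The example below is an illustrative sketch of that flow — downloading the electricity data, calibrating the per-entity scalers on the training split, and wrapping each split in the dataset class. The file paths and batch size are placeholders and may differ from what the notebook uses.

```python
import pandas as pd
from torch.utils.data import DataLoader

from data.data_download import Config, download_electricity
from data.custom_dataset import TFTDataset
from data_formatters.electricity import ElectricityFormatter

# Download the UCI electricity data and aggregate it to hourly resolution.
config = Config(data_folder='data', csv_path='data/electricity.csv')  # illustrative paths
download_electricity(config)

# Split into train/valid/test, fitting the scalers on the training split only.
df = pd.read_csv(config.data_csv_path, index_col=0)
formatter = ElectricityFormatter()
train, valid, test = formatter.split_data(df)

# Each TFTDataset item is an (inputs, outputs, active_entries) tuple of arrays.
train_loader = DataLoader(TFTDataset(train), batch_size=64, shuffle=True)
valid_loader = DataLoader(TFTDataset(valid), batch_size=64)
```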
15 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
--------------------------------------------------------------------------------
/data_formatters/traffic.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from utils.utils import get_single_col_by_input_type
3 | from utils.utils import extract_cols_from_data_type
4 | from .base import GenericDataFormatter
5 | from .volatility import VolatilityFormatter
6 | from .base import DataTypes, InputTypes
7 | import sklearn.preprocessing
8 | 
9 | class TrafficFormatter(VolatilityFormatter):
10 |   """Defines and formats data for the traffic dataset.
11 |   This also performs z-score normalization across the entire dataset, hence
12 |   re-uses most of the same functions as volatility.
13 |   Attributes:
14 |     column_definition: Defines input and data type of column used in the
15 |       experiment.
16 |     identifiers: Entity identifiers used in experiments.
17 | """ 18 | 19 | _column_definition = [ 20 | ('id', DataTypes.REAL_VALUED, InputTypes.ID), 21 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME), 22 | ('values', DataTypes.REAL_VALUED, InputTypes.TARGET), 23 | ('time_on_day', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 24 | ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 25 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 26 | ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT), 27 | ] 28 | 29 | def split_data(self, df, valid_boundary=151, test_boundary=166): 30 | """Splits data frame into training-validation-test data frames. 31 | This also calibrates scaling object, and transforms data for each split. 32 | Args: 33 | df: Source data frame to split. 34 | valid_boundary: Starting year for validation data 35 | test_boundary: Starting year for test data 36 | Returns: 37 | Tuple of transformed (train, valid, test) data. 38 | """ 39 | 40 | print('Formatting train-valid-test splits.') 41 | 42 | index = df['sensor_day'] 43 | train = df.loc[index < valid_boundary] 44 | valid = df.loc[(index >= valid_boundary - 7) & (index < test_boundary)] 45 | test = df.loc[index >= test_boundary - 7] 46 | 47 | self.set_scalers(train) 48 | 49 | return (self.transform_inputs(data) for data in [train, valid, test]) 50 | 51 | # Default params 52 | def get_fixed_params(self): 53 | """Returns fixed model parameters for experiments.""" 54 | 55 | fixed_params = { 56 | 'total_time_steps': 8 * 24, 57 | 'num_encoder_steps': 7 * 24, 58 | 'num_epochs': 100, 59 | 'early_stopping_patience': 5, 60 | 'multiprocessing_workers': 5 61 | } 62 | 63 | return fixed_params 64 | 65 | def get_default_model_params(self): 66 | """Returns default optimised model parameters.""" 67 | 68 | model_params = { 69 | 'dropout_rate': 0.3, 70 | 'hidden_layer_size': 320, 71 | 'learning_rate': 0.001, 72 | 'minibatch_size': 128, 73 | 'max_gradient_norm': 100., 74 | 'num_heads': 4, 75 | 'stack_size': 1 76 | } 77 | 78 | return model_params 79 | 80 | def get_num_samples_for_calibration(self): 81 | """Gets the default number of training and validation samples. 82 | Use to sub-sample the data for network calibration and a value of -1 uses 83 | all available samples. 84 | Returns: 85 | Tuple of (training samples, validation samples) 86 | """ 87 | return 450000, 50000 -------------------------------------------------------------------------------- /data/data_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import pyunpack 4 | import pandas as pd 5 | import numpy as np 6 | 7 | def download_from_url(url, output_path): 8 | """Downloads a file froma url.""" 9 | 10 | print('Pulling data from {} to {}'.format(url, output_path)) 11 | wget.download(url, output_path) 12 | print('done') 13 | 14 | def recreate_folder(path): 15 | """Deletes and recreates folder.""" 16 | 17 | shutil.rmtree(path) 18 | os.makedirs(path) 19 | 20 | def unzip(zip_path, output_file, data_folder): 21 | """Unzips files and checks successful completion.""" 22 | 23 | print('Unzipping file: {}'.format(zip_path)) 24 | pyunpack.Archive(zip_path).extractall(data_folder) 25 | 26 | # Checks if unzip was successful 27 | if not os.path.exists(output_file): 28 | raise ValueError( 29 | 'Error in unzipping process! {} not found.'.format(output_file)) 30 | 31 | def download_and_unzip(url, zip_path, csv_path, data_folder): 32 | """Downloads and unzips an online csv file. 
33 | Args: 34 | url: Web address 35 | zip_path: Path to download zip file 36 | csv_path: Expected path to csv file 37 | data_folder: Folder in which data is stored. 38 | """ 39 | 40 | download_from_url(url, zip_path) 41 | 42 | unzip(zip_path, csv_path, data_folder) 43 | 44 | print('Done.') 45 | 46 | def download_electricity(config): 47 | """Downloads electricity dataset from UCI repository.""" 48 | 49 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip' 50 | 51 | data_folder = config.data_folder 52 | csv_path = os.path.join(data_folder, 'LD2011_2014.txt') 53 | zip_path = csv_path + '.zip' 54 | 55 | download_and_unzip(url, zip_path, csv_path, data_folder) 56 | 57 | print('Aggregating to hourly data') 58 | 59 | df = pd.read_csv(csv_path, index_col=0, sep=';', decimal=',') 60 | df.index = pd.to_datetime(df.index) 61 | df.sort_index(inplace=True) 62 | 63 | # Used to determine the start and end dates of a series 64 | output = df.resample('1h').mean().replace(0., np.nan) 65 | 66 | earliest_time = output.index.min() 67 | 68 | df_list = [] 69 | for label in output: 70 | print('Processing {}'.format(label)) 71 | srs = output[label] 72 | 73 | start_date = min(srs.fillna(method='ffill').dropna().index) 74 | end_date = max(srs.fillna(method='bfill').dropna().index) 75 | 76 | active_range = (srs.index >= start_date) & (srs.index <= end_date) 77 | srs = srs[active_range].fillna(0.) 78 | 79 | tmp = pd.DataFrame({'power_usage': srs}) 80 | date = tmp.index 81 | tmp['t'] = (date - earliest_time).seconds / 60 / 60 + ( 82 | date - earliest_time).days * 24 83 | tmp['days_from_start'] = (date - earliest_time).days 84 | tmp['categorical_id'] = label 85 | tmp['date'] = date 86 | tmp['id'] = label 87 | tmp['hour'] = date.hour 88 | tmp['day'] = date.day 89 | tmp['day_of_week'] = date.dayofweek 90 | tmp['month'] = date.month 91 | 92 | df_list.append(tmp) 93 | 94 | output = pd.concat(df_list, axis=0, join='outer').reset_index(drop=True) 95 | 96 | output['categorical_id'] = output['id'].copy() 97 | output['hours_from_start'] = output['t'] 98 | output['categorical_day_of_week'] = output['day_of_week'].copy() 99 | output['categorical_hour'] = output['hour'].copy() 100 | 101 | # Filter to match range used by other academic papers 102 | output = output[(output['days_from_start'] >= 1096) 103 | & (output['days_from_start'] < 1346)].copy() 104 | 105 | output.to_csv(config.data_csv_path) 106 | 107 | print('Done.') 108 | 109 | class Config(): 110 | def __init__(self, data_folder, csv_path): 111 | self.data_folder = data_folder 112 | self.data_csv_path = csv_path -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from utils.utils import get_single_col_by_input_type 5 | from utils.utils import extract_cols_from_data_type 6 | from data_formatters.electricity import ElectricityFormatter 7 | from data_formatters.base import DataTypes, InputTypes 8 | 9 | class TFTDataset(Dataset, ElectricityFormatter): 10 | """Dataset Basic Structure for Temporal Fusion Transformer""" 11 | 12 | def __init__(self, 13 | data_df): 14 | super(ElectricityFormatter, self).__init__() 15 | """ 16 | Args: 17 | csv_file (string): Path to the csv file with annotations. 
18 | """ 19 | # Attribute loading the data 20 | self.data = data_df.reset_index(drop=True) 21 | 22 | self.id_col = get_single_col_by_input_type(InputTypes.ID, self._column_definition) 23 | self.time_col = get_single_col_by_input_type(InputTypes.TIME, self._column_definition) 24 | self.target_col = get_single_col_by_input_type(InputTypes.TARGET, self._column_definition) 25 | self.input_cols = [ 26 | tup[0] 27 | for tup in self._column_definition 28 | if tup[2] not in {InputTypes.ID, InputTypes.TIME} 29 | ] 30 | self.col_mappings = { 31 | 'identifier': [self.id_col], 32 | 'time': [self.time_col], 33 | 'outputs': [self.target_col], 34 | 'inputs': self.input_cols 35 | } 36 | self.lookback = self.get_time_steps() 37 | self.num_encoder_steps = self.get_num_encoder_steps() 38 | 39 | self.data_index = self.get_index_filtering() 40 | self.group_size = self.data.groupby([self.id_col]).apply(lambda x: x.shape[0]).mean() 41 | self.data_index = self.data_index[self.data_index.end_rel < self.group_size].reset_index() 42 | 43 | def get_index_filtering(self): 44 | 45 | g = self.data.groupby([self.id_col]) 46 | 47 | df_index_abs = g[[self.target_col]].transform(lambda x: x.index+self.lookback) \ 48 | .reset_index() \ 49 | .rename(columns={'index':'init_abs', 50 | self.target_col:'end_abs'}) 51 | df_index_rel_init = g[[self.target_col]].transform(lambda x: x.reset_index(drop=True).index) \ 52 | .rename(columns={self.target_col:'init_rel'}) 53 | df_index_rel_end = g[[self.target_col]].transform(lambda x: x.reset_index(drop=True).index+self.lookback) \ 54 | .rename(columns={self.target_col:'end_rel'}) 55 | df_total_count = g[[self.target_col]].transform(lambda x: x.shape[0] - self.lookback + 1) \ 56 | .rename(columns = {self.target_col:'group_count'}) 57 | 58 | return pd.concat([df_index_abs, 59 | df_index_rel_init, 60 | df_index_rel_end, 61 | self.data[[self.id_col]], 62 | df_total_count], axis = 1).reset_index(drop = True) 63 | 64 | def __len__(self): 65 | # In this case, the length of the dataset is not the length of the training data, 66 | # rather the ammount of unique sequences in the data 67 | return self.data_index.shape[0] 68 | 69 | def __getitem__(self, idx): 70 | 71 | data_index = self.data.iloc[self.data_index.init_abs.iloc[idx]:self.data_index.end_abs.iloc[idx]] 72 | 73 | data_map = {} 74 | for k in self.col_mappings: 75 | cols = self.col_mappings[k] 76 | 77 | if k not in data_map: 78 | data_map[k] = [data_index[cols].values] 79 | else: 80 | data_map[k].append(data_index[cols].values) 81 | 82 | # Combine all data 83 | for k in data_map: 84 | data_map[k] = np.concatenate(data_map[k], axis=0) 85 | # Shorten target so we only get decoder steps 86 | data_map['outputs'] = data_map['outputs'][self.num_encoder_steps:, :] 87 | 88 | active_entries = np.ones_like(data_map['outputs']) 89 | if 'active_entries' not in data_map: 90 | data_map['active_entries'] = active_entries 91 | else: 92 | data_map['active_entries'].append(active_entries) 93 | 94 | return data_map['inputs'], data_map['outputs'], data_map['active_entries'] -------------------------------------------------------------------------------- /quantile_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class QuantileLossCalculator(): 4 | """Computes the combined quantile loss for prespecified quantiles. 5 | Attributes: 6 | quantiles: Quantiles to compute losses 7 | """ 8 | 9 | def __init__(self, quantiles, output_size): 10 | """Initializes computer with quantiles for loss calculations. 
11 | Args: 12 | quantiles: Quantiles to use for computations. 13 | """ 14 | self.quantiles = quantiles 15 | self.output_size = output_size 16 | 17 | # Loss functions. 18 | def quantile_loss(self, y, y_pred, quantile): 19 | """ Computes quantile loss for pytorch. 20 | Standard quantile loss as defined in the "Training Procedure" section of 21 | the main TFT paper 22 | Args: 23 | y: Targets 24 | y_pred: Predictions 25 | quantile: Quantile to use for loss calculations (between 0 & 1) 26 | Returns: 27 | Tensor for quantile loss. 28 | """ 29 | 30 | # Checks quantile 31 | if quantile < 0 or quantile > 1: 32 | raise ValueError( 33 | 'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile)) 34 | 35 | prediction_underflow = y - y_pred 36 | # print('prediction_underflow') 37 | # print(prediction_underflow.shape) 38 | q_loss = quantile * torch.max(prediction_underflow, torch.zeros_like(prediction_underflow)) + \ 39 | (1. - quantile) * torch.max(-prediction_underflow, torch.zeros_like(prediction_underflow)) 40 | 41 | # print('q_loss') 42 | # print(q_loss.shape) 43 | 44 | # loss = torch.mean(q_loss, dim = 1) 45 | # print('loss') 46 | # print(loss.shape) 47 | # return loss 48 | 49 | # return torch.sum(q_loss, dim = -1) 50 | return q_loss.unsqueeze(1) 51 | 52 | def apply(self, b, a): 53 | """Returns quantile loss for specified quantiles. 54 | Args: 55 | a: Targets 56 | b: Predictions 57 | """ 58 | quantiles_used = set(self.quantiles) 59 | 60 | loss = [] 61 | # loss = 0. 62 | for i, quantile in enumerate(self.quantiles): 63 | if quantile in quantiles_used: 64 | #print(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)].shape) 65 | # loss += self.quantile_loss(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 66 | # b[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 67 | # quantile) 68 | #print(a[Ellipsis, self.output_size * i].shape) 69 | #loss += self.quantile_loss(a[Ellipsis, self.output_size * i], 70 | # b[Ellipsis, self.output_size * i], 71 | # quantile) 72 | 73 | # loss.append(self.quantile_loss(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 74 | # b[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 75 | # quantile)) 76 | 77 | loss.append(self.quantile_loss(a[Ellipsis, i], 78 | b[Ellipsis, i], 79 | quantile)) 80 | 81 | # loss_computed = torch.cat(loss, axis = -1) 82 | # loss_computed = torch.sum(loss_computed, axis = -1) 83 | # loss_computed = torch.sum(loss_computed, axis = 0) 84 | 85 | loss_computed = torch.mean(torch.sum(torch.cat(loss, axis = 1), axis = 1)) 86 | 87 | return loss_computed 88 | # return loss 89 | 90 | class NormalizedQuantileLossCalculator(): 91 | """Computes the combined quantile loss for prespecified quantiles. 92 | Attributes: 93 | quantiles: Quantiles to compute losses 94 | """ 95 | 96 | def __init__(self, quantiles, output_size): 97 | """Initializes computer with quantiles for loss calculations. 98 | Args: 99 | quantiles: Quantiles to use for computations. 100 | """ 101 | self.quantiles = quantiles 102 | self.output_size = output_size 103 | 104 | # Loss functions. 105 | def apply(self, y, y_pred, quantile): 106 | """ Computes quantile loss for pytorch. 107 | Standard quantile loss as defined in the "Training Procedure" section of 108 | the main TFT paper 109 | Args: 110 | y: Targets 111 | y_pred: Predictions 112 | quantile: Quantile to use for loss calculations (between 0 & 1) 113 | Returns: 114 | Tensor for quantile loss. 
115 | """ 116 | 117 | # Checks quantile 118 | if quantile < 0 or quantile > 1: 119 | raise ValueError( 120 | 'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile)) 121 | 122 | prediction_underflow = y - y_pred 123 | # print('prediction_underflow') 124 | # print(prediction_underflow.shape) 125 | weighted_errors = quantile * torch.max(prediction_underflow, torch.zeros_like(prediction_underflow)) + \ 126 | (1. - quantile) * torch.max(-prediction_underflow, torch.zeros_like(prediction_underflow)) 127 | 128 | quantile_loss = torch.mean(weighted_errors) 129 | normaliser = torch.mean(torch.abs(quantile_loss)) 130 | return 2 * quantile_loss / normaliser -------------------------------------------------------------------------------- /data_formatters/volatility.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from utils.utils import get_single_col_by_input_type 3 | from utils.utils import extract_cols_from_data_type 4 | from .base import GenericDataFormatter 5 | from .base import DataTypes, InputTypes 6 | import sklearn.preprocessing 7 | 8 | class VolatilityFormatter(GenericDataFormatter): 9 | """Defines and formats data for the volatility dataset. 10 | Attributes: 11 | column_definition: Defines input and data type of column used in the 12 | experiment. 13 | identifiers: Entity identifiers used in experiments. 14 | """ 15 | 16 | _column_definition = [ 17 | ('Symbol', DataTypes.CATEGORICAL, InputTypes.ID), 18 | ('date', DataTypes.DATE, InputTypes.TIME), 19 | ('log_vol', DataTypes.REAL_VALUED, InputTypes.TARGET), 20 | ('open_to_close', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT), 21 | ('days_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 22 | ('day_of_week', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT), 23 | ('day_of_month', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT), 24 | ('week_of_year', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT), 25 | ('month', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT), 26 | ('Region', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT), 27 | ] 28 | 29 | def __init__(self): 30 | """Initialises formatter.""" 31 | 32 | self.identifiers = None 33 | self._real_scalers = None 34 | self._cat_scalers = None 35 | self._target_scaler = None 36 | self._num_classes_per_cat_input = None 37 | 38 | def split_data(self, df, valid_boundary=2016, test_boundary=2018): 39 | """Splits data frame into training-validation-test data frames. 40 | This also calibrates scaling object, and transforms data for each split. 41 | Args: 42 | df: Source data frame to split. 43 | valid_boundary: Starting year for validation data 44 | test_boundary: Starting year for test data 45 | Returns: 46 | Tuple of transformed (train, valid, test) data. 47 | """ 48 | 49 | print('Formatting train-valid-test splits.') 50 | 51 | index = df['year'] 52 | train = df.loc[index < valid_boundary] 53 | valid = df.loc[(index >= valid_boundary) & (index < test_boundary)] 54 | test = df.loc[index >= test_boundary] 55 | 56 | self.set_scalers(train) 57 | 58 | return (self.transform_inputs(data) for data in [train, valid, test]) 59 | 60 | def set_scalers(self, df): 61 | """Calibrates scalers using the data supplied. 62 | Args: 63 | df: Data to use to calibrate scalers. 
64 | """ 65 | print('Setting scalers with training data...') 66 | 67 | column_definitions = self.get_column_definition() 68 | id_column = utils.get_single_col_by_input_type(InputTypes.ID, 69 | column_definitions) 70 | target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, 71 | column_definitions) 72 | 73 | # Extract identifiers in case required 74 | self.identifiers = list(df[id_column].unique()) 75 | 76 | # Format real scalers 77 | real_inputs = utils.extract_cols_from_data_type( 78 | DataTypes.REAL_VALUED, column_definitions, 79 | {InputTypes.ID, InputTypes.TIME}) 80 | 81 | data = df[real_inputs].values 82 | self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data) 83 | self._target_scaler = sklearn.preprocessing.StandardScaler().fit( 84 | df[[target_column]].values) # used for predictions 85 | 86 | # Format categorical scalers 87 | categorical_inputs = utils.extract_cols_from_data_type( 88 | DataTypes.CATEGORICAL, column_definitions, 89 | {InputTypes.ID, InputTypes.TIME}) 90 | 91 | categorical_scalers = {} 92 | num_classes = [] 93 | for col in categorical_inputs: 94 | # Set all to str so that we don't have mixed integer/string columns 95 | srs = df[col].apply(str) 96 | categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit( 97 | srs.values) 98 | num_classes.append(srs.nunique()) 99 | 100 | # Set categorical scaler outputs 101 | self._cat_scalers = categorical_scalers 102 | self._num_classes_per_cat_input = num_classes 103 | 104 | def transform_inputs(self, df): 105 | """Performs feature transformations. 106 | This includes both feature engineering, preprocessing and normalisation. 107 | Args: 108 | df: Data frame to transform. 109 | Returns: 110 | Transformed data frame. 111 | """ 112 | output = df.copy() 113 | 114 | if self._real_scalers is None and self._cat_scalers is None: 115 | raise ValueError('Scalers have not been set!') 116 | 117 | column_definitions = self.get_column_definition() 118 | 119 | real_inputs = utils.extract_cols_from_data_type( 120 | DataTypes.REAL_VALUED, column_definitions, 121 | {InputTypes.ID, InputTypes.TIME}) 122 | categorical_inputs = utils.extract_cols_from_data_type( 123 | DataTypes.CATEGORICAL, column_definitions, 124 | {InputTypes.ID, InputTypes.TIME}) 125 | 126 | # Format real inputs 127 | output[real_inputs] = self._real_scalers.transform(df[real_inputs].values) 128 | 129 | # Format categorical inputs 130 | for col in categorical_inputs: 131 | string_df = df[col].apply(str) 132 | output[col] = self._cat_scalers[col].transform(string_df) 133 | 134 | return output 135 | 136 | def format_predictions(self, predictions): 137 | """Reverts any normalisation to give predictions in original scale. 138 | Args: 139 | predictions: Dataframe of model predictions. 140 | Returns: 141 | Data frame of unnormalised predictions. 
142 | """ 143 | output = predictions.copy() 144 | 145 | column_names = predictions.columns 146 | 147 | for col in column_names: 148 | if col not in {'forecast_time', 'identifier'}: 149 | output[col] = self._target_scaler.inverse_transform(predictions[col]) 150 | 151 | return output 152 | 153 | # Default params 154 | def get_fixed_params(self): 155 | """Returns fixed model parameters for experiments.""" 156 | 157 | fixed_params = { 158 | 'total_time_steps': 252 + 5, 159 | 'num_encoder_steps': 252, 160 | 'num_epochs': 100, 161 | 'early_stopping_patience': 5, 162 | 'multiprocessing_workers': 5, 163 | } 164 | 165 | return fixed_params 166 | 167 | def get_default_model_params(self): 168 | """Returns default optimised model parameters.""" 169 | 170 | model_params = { 171 | 'dropout_rate': 0.3, 172 | 'hidden_layer_size': 160, 173 | 'learning_rate': 0.01, 174 | 'minibatch_size': 64, 175 | 'max_gradient_norm': 0.01, 176 | 'num_heads': 1, 177 | 'stack_size': 1 178 | } 179 | 180 | return model_params -------------------------------------------------------------------------------- /data_formatters/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import enum 3 | from utils.utils import get_single_col_by_input_type 4 | from utils.utils import extract_cols_from_data_type 5 | 6 | # Type defintions 7 | class DataTypes(enum.IntEnum): 8 | """Defines numerical types of each column.""" 9 | REAL_VALUED = 0 10 | CATEGORICAL = 1 11 | DATE = 2 12 | 13 | class InputTypes(enum.IntEnum): 14 | """Defines input types of each column.""" 15 | TARGET = 0 16 | OBSERVED_INPUT = 1 17 | KNOWN_INPUT = 2 18 | STATIC_INPUT = 3 19 | ID = 4 # Single column used as an entity identifier 20 | TIME = 5 # Single column exclusively used as a time index 21 | 22 | class GenericDataFormatter(abc.ABC): 23 | """Abstract base class for all data formatters. 24 | User can implement the abstract methods below to perform dataset-specific 25 | manipulations. 26 | """ 27 | 28 | @abc.abstractmethod 29 | def set_scalers(self, df): 30 | """Calibrates scalers using the data supplied.""" 31 | raise NotImplementedError() 32 | 33 | @abc.abstractmethod 34 | def transform_inputs(self, df): 35 | """Performs feature transformation.""" 36 | raise NotImplementedError() 37 | 38 | @abc.abstractmethod 39 | def format_predictions(self, df): 40 | """Reverts any normalisation to give predictions in original scale.""" 41 | raise NotImplementedError() 42 | 43 | @abc.abstractmethod 44 | def split_data(self, df): 45 | """Performs the default train, validation and test splits.""" 46 | raise NotImplementedError() 47 | 48 | @property 49 | @abc.abstractmethod 50 | def _column_definition(self): 51 | """Defines order, input type and data type of each column.""" 52 | raise NotImplementedError() 53 | 54 | @abc.abstractmethod 55 | def get_fixed_params(self): 56 | """Defines the fixed parameters used by the model for training. 57 | Requires the following keys: 58 | 'total_time_steps': Defines the total number of time steps used by TFT 59 | 'num_encoder_steps': Determines length of LSTM encoder (i.e. 
history) 60 | 'num_epochs': Maximum number of epochs for training 61 | 'early_stopping_patience': Early stopping param for keras 62 | 'multiprocessing_workers': # of cpus for data processing 63 | Returns: 64 | A dictionary of fixed parameters, e.g.: 65 | fixed_params = { 66 | 'total_time_steps': 252 + 5, 67 | 'num_encoder_steps': 252, 68 | 'num_epochs': 100, 69 | 'early_stopping_patience': 5, 70 | 'multiprocessing_workers': 5, 71 | } 72 | """ 73 | raise NotImplementedError() 74 | 75 | # Shared functions across data-formatters 76 | @property 77 | def num_classes_per_cat_input(self): 78 | """Returns number of categories per relevant input. 79 | This is seqeuently required for keras embedding layers. 80 | """ 81 | return self._num_classes_per_cat_input 82 | 83 | def get_num_samples_for_calibration(self): 84 | """Gets the default number of training and validation samples. 85 | Use to sub-sample the data for network calibration and a value of -1 uses 86 | all available samples. 87 | Returns: 88 | Tuple of (training samples, validation samples) 89 | """ 90 | return -1, -1 91 | 92 | def get_column_definition(self): 93 | """"Returns formatted column definition in order expected by the TFT.""" 94 | 95 | column_definition = self._column_definition 96 | 97 | # Sanity checks first. 98 | # Ensure only one ID and time column exist 99 | def _check_single_column(input_type): 100 | 101 | length = len([tup for tup in column_definition if tup[2] == input_type]) 102 | 103 | if length != 1: 104 | raise ValueError('Illegal number of inputs ({}) of type {}'.format( 105 | length, input_type)) 106 | 107 | _check_single_column(InputTypes.ID) 108 | _check_single_column(InputTypes.TIME) 109 | 110 | identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID] 111 | time = [tup for tup in column_definition if tup[2] == InputTypes.TIME] 112 | real_inputs = [ 113 | tup for tup in column_definition if tup[1] == DataTypes.REAL_VALUED and 114 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 115 | ] 116 | categorical_inputs = [ 117 | tup for tup in column_definition if tup[1] == DataTypes.CATEGORICAL and 118 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 119 | ] 120 | 121 | return identifier + time + real_inputs + categorical_inputs 122 | 123 | def _get_input_columns(self): 124 | """Returns names of all input columns.""" 125 | return [ 126 | tup[0] 127 | for tup in self.get_column_definition() 128 | if tup[2] not in {InputTypes.ID, InputTypes.TIME} 129 | ] 130 | 131 | def _get_tft_input_indices(self): 132 | """Returns the relevant indexes and input sizes required by TFT.""" 133 | 134 | # Functions 135 | def _extract_tuples_from_data_type(data_type, defn): 136 | return [ 137 | tup for tup in defn if tup[1] == data_type and 138 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 139 | ] 140 | 141 | def _get_locations(input_types, defn): 142 | return [i for i, tup in enumerate(defn) if tup[2] in input_types] 143 | 144 | # Start extraction 145 | column_definition = [ 146 | tup for tup in self.get_column_definition() 147 | if tup[2] not in {InputTypes.ID, InputTypes.TIME} 148 | ] 149 | 150 | categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, 151 | column_definition) 152 | real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, 153 | column_definition) 154 | 155 | locations = { 156 | 'input_size': 157 | len(self._get_input_columns()), 158 | 'output_size': 159 | len(_get_locations({InputTypes.TARGET}, column_definition)), 160 | 'category_counts': 161 | self.num_classes_per_cat_input, 
162 | 'input_obs_loc': 163 | _get_locations({InputTypes.TARGET}, column_definition), 164 | 'static_input_loc': 165 | _get_locations({InputTypes.STATIC_INPUT}, column_definition), 166 | 'known_regular_inputs': 167 | _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, 168 | real_inputs), 169 | 'known_categorical_inputs': 170 | _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, 171 | categorical_inputs), 172 | } 173 | 174 | return locations 175 | 176 | def get_experiment_params(self): 177 | """Returns fixed model parameters for experiments.""" 178 | 179 | required_keys = [ 180 | 'total_time_steps', 'num_encoder_steps', 'num_epochs', 181 | 'early_stopping_patience', 'multiprocessing_workers' 182 | ] 183 | 184 | fixed_params = self.get_fixed_params() 185 | 186 | for k in required_keys: 187 | if k not in fixed_params: 188 | raise ValueError('Field {}'.format(k) + 189 | ' missing from fixed parameter definitions!') 190 | 191 | fixed_params['column_definition'] = self.get_column_definition() 192 | 193 | fixed_params.update(self._get_tft_input_indices()) 194 | 195 | return fixed_params -------------------------------------------------------------------------------- /data_formatters/electricity.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from utils.utils import get_single_col_by_input_type 3 | from utils.utils import extract_cols_from_data_type 4 | from .base import GenericDataFormatter 5 | from .base import DataTypes, InputTypes 6 | import sklearn.preprocessing 7 | 8 | class ElectricityFormatter(GenericDataFormatter): 9 | """Defines and formats data for the electricity dataset. 10 | Note that per-entity z-score normalization is used here, and is implemented 11 | across functions. 12 | Attributes: 13 | column_definition: Defines input and data type of column used in the 14 | experiment. 15 | identifiers: Entity identifiers used in experiments. 16 | """ 17 | 18 | _column_definition = [ 19 | ('id', DataTypes.REAL_VALUED, InputTypes.ID), 20 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME), 21 | ('power_usage', DataTypes.REAL_VALUED, InputTypes.TARGET), 22 | ('hour', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 23 | ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 24 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 25 | ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT), 26 | ] 27 | 28 | def __init__(self): 29 | """Initialises formatter.""" 30 | 31 | self.identifiers = None 32 | self._real_scalers = None 33 | self._cat_scalers = None 34 | self._target_scaler = None 35 | self._num_classes_per_cat_input = None 36 | self._time_steps = self.get_fixed_params()['total_time_steps'] 37 | self._num_encoder_steps = self.get_fixed_params()['num_encoder_steps'] 38 | 39 | def get_time_steps(self): 40 | return self.get_fixed_params()['total_time_steps'] 41 | 42 | def get_num_encoder_steps(self): 43 | return self.get_fixed_params()['num_encoder_steps'] 44 | 45 | def split_data(self, df, valid_boundary=1315, test_boundary=1339): 46 | """Splits data frame into training-validation-test data frames. 47 | This also calibrates scaling object, and transforms data for each split. 48 | Args: 49 | df: Source data frame to split. 50 | valid_boundary: Starting year for validation data 51 | test_boundary: Starting year for test data 52 | Returns: 53 | Tuple of transformed (train, valid, test) data. 
54 | """ 55 | 56 | print('Formatting train-valid-test splits.') 57 | 58 | index = df['days_from_start'] 59 | train = df.loc[index < valid_boundary] 60 | valid = df.loc[(index >= valid_boundary - 7) & (index < test_boundary)] 61 | test = df.loc[index >= test_boundary - 7] 62 | 63 | self.set_scalers(train) 64 | 65 | return (self.transform_inputs(data) for data in [train, valid, test]) 66 | 67 | def set_scalers(self, df): 68 | """Calibrates scalers using the data supplied. 69 | Args: 70 | df: Data to use to calibrate scalers. 71 | """ 72 | print('Setting scalers with training data...') 73 | 74 | column_definitions = self.get_column_definition() 75 | id_column = get_single_col_by_input_type(InputTypes.ID, 76 | column_definitions) 77 | target_column = get_single_col_by_input_type(InputTypes.TARGET, 78 | column_definitions) 79 | 80 | # Format real scalers 81 | real_inputs = extract_cols_from_data_type( 82 | DataTypes.REAL_VALUED, column_definitions, 83 | {InputTypes.ID, InputTypes.TIME}) 84 | 85 | # Initialise scaler caches 86 | self._real_scalers = {} 87 | self._target_scaler = {} 88 | identifiers = [] 89 | for identifier, sliced in df.groupby(id_column): 90 | 91 | if len(sliced) >= self._time_steps: 92 | 93 | data = sliced[real_inputs].values 94 | targets = sliced[[target_column]].values 95 | self._real_scalers[identifier] \ 96 | = sklearn.preprocessing.StandardScaler().fit(data) 97 | 98 | self._target_scaler[identifier] \ 99 | = sklearn.preprocessing.StandardScaler().fit(targets) 100 | identifiers.append(identifier) 101 | 102 | # Format categorical scalers 103 | categorical_inputs = extract_cols_from_data_type( 104 | DataTypes.CATEGORICAL, column_definitions, 105 | {InputTypes.ID, InputTypes.TIME}) 106 | 107 | categorical_scalers = {} 108 | num_classes = [] 109 | for col in categorical_inputs: 110 | # Set all to str so that we don't have mixed integer/string columns 111 | srs = df[col].apply(str) 112 | categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit( 113 | srs.values) 114 | num_classes.append(srs.nunique()) 115 | 116 | # Set categorical scaler outputs 117 | self._cat_scalers = categorical_scalers 118 | self._num_classes_per_cat_input = num_classes 119 | 120 | # Extract identifiers in case required 121 | self.identifiers = identifiers 122 | 123 | def transform_inputs(self, df): 124 | """Performs feature transformations. 125 | This includes both feature engineering, preprocessing and normalisation. 126 | Args: 127 | df: Data frame to transform. 128 | Returns: 129 | Transformed data frame. 
130 | """ 131 | 132 | if self._real_scalers is None and self._cat_scalers is None: 133 | raise ValueError('Scalers have not been set!') 134 | 135 | # Extract relevant columns 136 | column_definitions = self.get_column_definition() 137 | id_col = get_single_col_by_input_type(InputTypes.ID, 138 | column_definitions) 139 | real_inputs = extract_cols_from_data_type( 140 | DataTypes.REAL_VALUED, column_definitions, 141 | {InputTypes.ID, InputTypes.TIME}) 142 | categorical_inputs = extract_cols_from_data_type( 143 | DataTypes.CATEGORICAL, column_definitions, 144 | {InputTypes.ID, InputTypes.TIME}) 145 | 146 | # Transform real inputs per entity 147 | df_list = [] 148 | for identifier, sliced in df.groupby(id_col): 149 | 150 | # Filter out any trajectories that are too short 151 | if len(sliced) >= self._time_steps: 152 | sliced_copy = sliced.copy() 153 | sliced_copy[real_inputs] = self._real_scalers[identifier].transform( 154 | sliced_copy[real_inputs].values) 155 | df_list.append(sliced_copy) 156 | 157 | output = pd.concat(df_list, axis=0) 158 | 159 | # Format categorical inputs 160 | for col in categorical_inputs: 161 | string_df = df[col].apply(str) 162 | output[col] = self._cat_scalers[col].transform(string_df) 163 | 164 | return output 165 | 166 | def format_predictions(self, predictions): 167 | """Reverts any normalisation to give predictions in original scale. 168 | Args: 169 | predictions: Dataframe of model predictions. 170 | Returns: 171 | Data frame of unnormalised predictions. 172 | """ 173 | 174 | if self._target_scaler is None: 175 | raise ValueError('Scalers have not been set!') 176 | 177 | column_names = predictions.columns 178 | 179 | df_list = [] 180 | for identifier, sliced in predictions.groupby('identifier'): 181 | sliced_copy = sliced.copy() 182 | target_scaler = self._target_scaler[identifier] 183 | 184 | for col in column_names: 185 | if col not in {'forecast_time', 'identifier'}: 186 | sliced_copy[col] = target_scaler.inverse_transform(sliced_copy[col]) 187 | df_list.append(sliced_copy) 188 | 189 | output = pd.concat(df_list, axis=0) 190 | 191 | return output 192 | 193 | # Default params 194 | def get_fixed_params(self): 195 | """Returns fixed model parameters for experiments.""" 196 | 197 | fixed_params = { 198 | 'total_time_steps': 8 * 24, 199 | 'num_encoder_steps': 7 * 24, 200 | 'num_epochs': 100, 201 | 'early_stopping_patience': 5, 202 | 'multiprocessing_workers': 5 203 | } 204 | 205 | return fixed_params 206 | 207 | def get_default_model_params(self): 208 | """Returns default optimised model parameters.""" 209 | 210 | model_params = { 211 | 'dropout_rate': 0.1, 212 | 'hidden_layer_size': 160, 213 | 'learning_rate': 0.001, 214 | 'minibatch_size': 64, 215 | 'max_gradient_norm': 0.01, 216 | 'num_heads': 4, 217 | 'stack_size': 1 218 | } 219 | 220 | return model_params 221 | 222 | def get_num_samples_for_calibration(self): 223 | """Gets the default number of training and validation samples. 224 | Use to sub-sample the data for network calibration and a value of -1 uses 225 | all available samples. 
226 | Returns: 227 | Tuple of (training samples, validation samples) 228 | """ 229 | return 450000, 50000 -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch.nn.functional as F 3 | from torchvision import transforms, utils 4 | from torch import tanh 5 | from torch import nn 6 | import torch 7 | 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import EarlyStopping 10 | 11 | import sklearn.preprocessing 12 | 13 | class GatedLinearUnit(nn.Module): 14 | def __init__(self, input_size, 15 | hidden_layer_size, 16 | dropout_rate, 17 | activation = None): 18 | 19 | super(GatedLinearUnit, self).__init__() 20 | 21 | self.input_size = input_size 22 | self.hidden_layer_size = hidden_layer_size 23 | self.dropout_rate = dropout_rate 24 | self.activation_name = activation 25 | 26 | if self.dropout_rate: 27 | self.dropout = nn.Dropout(p=self.dropout_rate) 28 | 29 | self.W4 = torch.nn.Linear(self.input_size, self.hidden_layer_size) 30 | self.W5 = torch.nn.Linear(self.input_size, self.hidden_layer_size) 31 | 32 | if self.activation_name: 33 | self.activation = getattr(nn, self.activation_name)() 34 | 35 | self.sigmoid = nn.Sigmoid() 36 | 37 | self.init_weights() 38 | 39 | def init_weights(self): 40 | for n, p in self.named_parameters(): 41 | if 'bias' not in n: 42 | torch.nn.init.xavier_uniform_(p) 43 | # torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid') 44 | elif 'bias' in n: 45 | torch.nn.init.zeros_(p) 46 | 47 | def forward(self, x): 48 | 49 | if self.dropout_rate: 50 | x = self.dropout(x) 51 | 52 | if self.activation_name: 53 | output = self.sigmoid(self.W4(x)) * self.activation(self.W5(x)) 54 | else: 55 | output = self.sigmoid(self.W4(x)) * self.W5(x) 56 | 57 | return output 58 | 59 | class GateAddNormNetwork(nn.Module): 60 | def __init__(self, input_size, 61 | hidden_layer_size, 62 | dropout_rate, 63 | activation = None): 64 | 65 | super(GateAddNormNetwork, self).__init__() 66 | 67 | self.input_size = input_size 68 | self.hidden_layer_size = hidden_layer_size 69 | self.dropout_rate = dropout_rate 70 | self.activation_name = activation 71 | 72 | self.GLU = GatedLinearUnit(self.input_size, 73 | self.hidden_layer_size, 74 | self.dropout_rate, 75 | activation = self.activation_name) 76 | 77 | self.LayerNorm = nn.LayerNorm(self.hidden_layer_size) 78 | 79 | def forward(self, x, skip): 80 | 81 | output = self.LayerNorm(self.GLU(x) + skip) 82 | 83 | return output 84 | 85 | class GatedResidualNetwork(nn.Module): 86 | def __init__(self, 87 | hidden_layer_size, 88 | input_size = None, 89 | output_size = None, 90 | dropout_rate = None, 91 | additional_context = None, 92 | return_gate = False): 93 | 94 | super(GatedResidualNetwork, self).__init__() 95 | 96 | self.hidden_layer_size = hidden_layer_size 97 | self.input_size = input_size if input_size else self.hidden_layer_size 98 | self.output_size = output_size 99 | self.dropout_rate = dropout_rate 100 | self.additional_context = additional_context 101 | self.return_gate = return_gate 102 | 103 | self.W1 = torch.nn.Linear(self.hidden_layer_size, self.hidden_layer_size) 104 | self.W2 = torch.nn.Linear(self.input_size, self.hidden_layer_size) 105 | 106 | if self.additional_context: 107 | self.W3 = torch.nn.Linear(self.additional_context, self.hidden_layer_size, bias = False) 108 | 109 | 110 | if self.output_size: 111 | 
self.skip_linear = torch.nn.Linear(self.input_size, self.output_size) 112 | self.glu_add_norm = GateAddNormNetwork(self.hidden_layer_size, 113 | self.output_size, 114 | self.dropout_rate) 115 | else: 116 | self.glu_add_norm = GateAddNormNetwork(self.hidden_layer_size, 117 | self.hidden_layer_size, 118 | self.dropout_rate) 119 | 120 | self.init_weights() 121 | 122 | def init_weights(self): 123 | for name, p in self.named_parameters(): 124 | if ('W2' in name or 'W3' in name) and 'bias' not in name: 125 | torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='leaky_relu') 126 | elif ('skip_linear' in name or 'W1' in name) and 'bias' not in name: 127 | torch.nn.init.xavier_uniform_(p) 128 | # torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid') 129 | elif 'bias' in name: 130 | torch.nn.init.zeros_(p) 131 | 132 | def forward(self, x): 133 | 134 | if self.additional_context: 135 | x, context = x 136 | #x_forward = self.W2(x) 137 | #context_forward = self.W3(context) 138 | #print(self.W3(context).shape) 139 | n2 = F.elu(self.W2(x) + self.W3(context)) 140 | else: 141 | n2 = F.elu(self.W2(x)) 142 | 143 | #print('n2 shape {}'.format(n2.shape)) 144 | 145 | n1 = self.W1(n2) 146 | 147 | #print('n1 shape {}'.format(n1.shape)) 148 | 149 | if self.output_size: 150 | output = self.glu_add_norm(n1, self.skip_linear(x)) 151 | else: 152 | output = self.glu_add_norm(n1, x) 153 | 154 | #print('output shape {}'.format(output.shape)) 155 | 156 | return output 157 | 158 | class ScaledDotProductAttention(nn.Module): 159 | def __init__(self, dropout = 0, scale = True): 160 | super(ScaledDotProductAttention, self).__init__() 161 | self.dropout = nn.Dropout(p=dropout) 162 | self.softmax = nn.Softmax(dim = 2) 163 | self.scale = scale 164 | 165 | def forward(self, q, k, v, mask = None): 166 | #print('---Inputs----') 167 | #print('q: {}'.format(q[0])) 168 | #print('k: {}'.format(k[0])) 169 | #print('v: {}'.format(v[0])) 170 | 171 | attn = torch.bmm(q, k.permute(0,2,1)) 172 | #print('first bmm') 173 | #print(attn.shape) 174 | #print('attn: {}'.format(attn[0])) 175 | 176 | if self.scale: 177 | dimention = torch.sqrt(torch.tensor(k.shape[-1]).to(torch.float32)) 178 | attn = attn / dimention 179 | # print('attn_scaled: {}'.format(attn[0])) 180 | 181 | if mask is not None: 182 | #fill = torch.tensor(-1e9).to(DEVICE) 183 | #zero = torch.tensor(0).to(DEVICE) 184 | attn = attn.masked_fill(mask == 0, -1e9) 185 | # print('attn_masked: {}'.format(attn[0])) 186 | 187 | attn = self.softmax(attn) 188 | #print('attn_softmax: {}'.format(attn[0])) 189 | attn = self.dropout(attn) 190 | 191 | output = torch.bmm(attn, v) 192 | 193 | return output, attn 194 | 195 | class InterpretableMultiHeadAttention(nn.Module): 196 | def __init__(self, n_head, d_model, dropout): 197 | super(InterpretableMultiHeadAttention, self).__init__() 198 | 199 | self.n_head = n_head 200 | self.d_model = d_model 201 | self.d_k = self.d_q = self.d_v = d_model // n_head 202 | self.dropout = nn.Dropout(p=dropout) 203 | 204 | self.v_layer = nn.Linear(self.d_model, self.d_v, bias = False) 205 | self.q_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_q, bias = False) 206 | for _ in range(self.n_head)]) 207 | self.k_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_k, bias = False) 208 | for _ in range(self.n_head)]) 209 | self.v_layers = nn.ModuleList([self.v_layer for _ in range(self.n_head)]) 210 | self.attention = ScaledDotProductAttention() 211 | self.w_h = nn.Linear(self.d_v, self.d_model, bias = False) 212 | 213 | 
self.init_weights() 214 | 215 | def init_weights(self): 216 | for name, p in self.named_parameters(): 217 | if 'bias' not in name: 218 | torch.nn.init.xavier_uniform_(p) 219 | # torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid') 220 | else: 221 | torch.nn.init.zeros_(p) 222 | 223 | def forward(self, q, k, v, mask = None): 224 | 225 | heads = [] 226 | attns = [] 227 | for i in range(self.n_head): 228 | qs = self.q_layers[i](q) 229 | ks = self.k_layers[i](k) 230 | vs = self.v_layers[i](v) 231 | #print('qs layer: {}'.format(qs.shape)) 232 | head, attn = self.attention(qs, ks, vs, mask) 233 | #print('head layer: {}'.format(head.shape)) 234 | #print('attn layer: {}'.format(attn.shape)) 235 | head_dropout = self.dropout(head) 236 | heads.append(head_dropout) 237 | attns.append(attn) 238 | 239 | head = torch.stack(heads, dim = 2) if self.n_head > 1 else heads[0] 240 | #print('concat heads: {}'.format(head.shape)) 241 | #print('heads {}: {}'.format(0, head[0,0,Ellipsis])) 242 | attn = torch.stack(attns, dim = 2) 243 | #print('concat attn: {}'.format(attn.shape)) 244 | 245 | outputs = torch.mean(head, dim = 2) if self.n_head > 1 else head 246 | #print('outputs mean: {}'.format(outputs.shape)) 247 | #print('outputs mean {}: {}'.format(0, outputs[0,0,Ellipsis])) 248 | outputs = self.w_h(outputs) 249 | outputs = self.dropout(outputs) 250 | 251 | return outputs, attn 252 | 253 | class VariableSelectionNetwork(nn.Module): 254 | def __init__(self, hidden_layer_size, 255 | dropout_rate, 256 | output_size, 257 | input_size = None, 258 | additional_context = None): 259 | super(VariableSelectionNetwork, self).__init__() 260 | 261 | self.hidden_layer_size = hidden_layer_size 262 | self.input_size = input_size 263 | self.output_size = output_size 264 | self.dropout_rate = dropout_rate 265 | self.additional_context = additional_context 266 | 267 | self.flattened_grn = GatedResidualNetwork(self.hidden_layer_size, 268 | input_size = self.input_size, 269 | output_size = self.output_size, 270 | dropout_rate = self.dropout_rate, 271 | additional_context=self.additional_context) 272 | 273 | self.per_feature_grn = nn.ModuleList([GatedResidualNetwork(self.hidden_layer_size, 274 | dropout_rate=self.dropout_rate) 275 | for i in range(self.output_size)]) 276 | def forward(self, x): 277 | # Non Static Inputs 278 | if self.additional_context: 279 | embedding, static_context = x 280 | #print('static_context') 281 | #print(static_context.shape) 282 | 283 | time_steps = embedding.shape[1] 284 | flatten = embedding.view(-1, time_steps, self.hidden_layer_size * self.output_size) 285 | #print('flatten') 286 | #print(flatten.shape) 287 | 288 | static_context = static_context.unsqueeze(1) 289 | #print('static_context') 290 | #print(static_context.shape) 291 | 292 | # Nonlinear transformation with gated residual network. 
293 | mlp_outputs = self.flattened_grn((flatten, static_context)) 294 | #print('mlp_outputs') 295 | #print(mlp_outputs.shape) 296 | 297 | sparse_weights = F.softmax(mlp_outputs, dim = -1) 298 | sparse_weights = sparse_weights.unsqueeze(2) 299 | #print('sparse_weights') 300 | #print(sparse_weights.shape) 301 | 302 | trans_emb_list = [] 303 | for i in range(self.output_size): 304 | e = self.per_feature_grn[i](embedding[Ellipsis, i]) 305 | trans_emb_list.append(e) 306 | transformed_embedding = torch.stack(trans_emb_list, axis=-1) 307 | #print('transformed_embedding') 308 | #print(transformed_embedding.shape) 309 | 310 | combined = sparse_weights * transformed_embedding 311 | #print('combined') 312 | #print(combined.shape) 313 | 314 | temporal_ctx = torch.sum(combined, dim = -1) 315 | #print('temporal_ctx') 316 | #print(temporal_ctx.shape) 317 | 318 | # Static Inputs 319 | else: 320 | embedding = x 321 | #print('embedding') 322 | #print(embedding.shape) 323 | 324 | flatten = torch.flatten(embedding, start_dim=1) 325 | #flatten = embedding.view(batch_size, -1) 326 | #print('flatten') 327 | #print(flatten.shape) 328 | 329 | # Nonlinear transformation with gated residual network. 330 | mlp_outputs = self.flattened_grn(flatten) 331 | #print('mlp_outputs') 332 | #print(mlp_outputs.shape) 333 | 334 | sparse_weights = F.softmax(mlp_outputs, dim = -1) 335 | sparse_weights = sparse_weights.unsqueeze(-1) 336 | # print('sparse_weights') 337 | # print(sparse_weights.shape) 338 | 339 | trans_emb_list = [] 340 | for i in range(self.output_size): 341 | #print('embedding for the per feature static grn') 342 | #print(embedding[:, i:i + 1, :].shape) 343 | e = self.per_feature_grn[i](embedding[:, i:i + 1, :]) 344 | trans_emb_list.append(e) 345 | transformed_embedding = torch.cat(trans_emb_list, axis=1) 346 | # print('transformed_embedding') 347 | # print(transformed_embedding.shape) 348 | 349 | combined = sparse_weights * transformed_embedding 350 | # print('combined') 351 | # print(combined.shape) 352 | 353 | temporal_ctx = torch.sum(combined, dim = 1) 354 | # print('temporal_ctx') 355 | # print(temporal_ctx.shape) 356 | 357 | return temporal_ctx, sparse_weights --------------------------------------------------------------------------------
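Usage note: `QuantileLossCalculator.apply` in quantile_loss.py takes predictions first and targets second, and it indexes the last axis once per quantile, so targets need to be repeated along that axis. The following is a minimal, illustrative sketch — the batch size, horizon and quantiles here are arbitrary rather than taken from the notebook:

```python
import torch
from quantile_loss import QuantileLossCalculator

quantiles = [0.1, 0.5, 0.9]                             # quantiles used in the TFT paper
loss_fn = QuantileLossCalculator(quantiles, output_size=1)

y_pred = torch.randn(64, 24, len(quantiles))            # (batch, decoder steps, quantiles)
y = torch.randn(64, 24)                                 # ground-truth targets
y_true = y.unsqueeze(-1).repeat(1, 1, len(quantiles))   # repeat targets along the quantile axis

loss = loss_fn.apply(y_pred, y_true)                    # predictions first, targets second
print(loss)                                             # scalar training loss
```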