├── .gitignore ├── LICENSE ├── README.md ├── conf.etdataset.gridsearch.yml ├── conf.etdataset.yml ├── download_etdataset.py ├── main.py ├── readme_figures ├── loss.png └── preds.png ├── requirements.txt ├── tests ├── test_csv.csv ├── test_model.py └── test_tsmixer.py └── utils ├── __init__.py ├── load_csv.py ├── model.py ├── plotting.py ├── tsmixer.py ├── tsmixer_conf.py └── tsmixer_grid_search_conf.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | output*/ 4 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Oliver K. Ernst 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TSMixer in PyTorch 2 | 3 | Reimplementation of TSMixer in PyTorch. 
4 | 5 | * Original paper: [https://arxiv.org/pdf/2303.06053.pdf](https://arxiv.org/abs/2303.06053) 6 | * Similar implementations: [https://github.com/marcopeix/time-series-analysis/blob/master/TSMixer.ipynb](https://github.com/marcopeix/time-series-analysis/blob/master/TSMixer.ipynb) 7 | 8 | ## Sample results 9 | 10 | ![Predictions on validation set](readme_figures/preds.png) 11 | *Predictions on validation set* 12 | 13 | ![Training loss](readme_figures/loss.png) 14 | *Loss during training* 15 | 16 | Parameters used for example: 17 | * `input_length`: 512 18 | * `prediction_length`: 96 19 | * `no_features`: 7 20 | * `no_mixer_layers`: 4 21 | * `dataset`: ETTh1.csv 22 | * `batch_size`: 32 23 | * `num_epochs`: 100 with early stopping after 5 epochs without improvement 24 | * `learning_rate`: 0.00001 25 | * `optimizer`: Adam 26 | * `validation_split_holdout`: 0.2 - last 20% of the time series data is used for validation 27 | * `dropout`: 0.3 28 | * `feat_mixing_hidden_channels`: 256 - number of hidden channels in the feature mixing layer 29 | 30 | ## Data 31 | 32 | You can find the raw ETDataset data [here](https://github.com/zhouhaoyi/ETDataset/tree/11ab373cf9c9f5be7698e219a5a170e1b1c8a930), specifically: 33 | 34 | * [ETTh1.csv](https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTh1.csv) 35 | * [ETTh2.csv](https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTh2.csv) 36 | * [ETTm1.csv](https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTm1.csv) 37 | * [ETTm2.csv](https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTm2.csv) 38 | 39 | You can use the `download_etdataset.py` script to download the data: 40 | 41 | ```bash 42 | python download_etdataset.py 43 | ``` 44 | 45 | ## Running 46 | 47 | Install the requirements: 48 | 49 | ```bash 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | Train the model: 54 | 55 | ```bash 56 | python main.py --conf conf.etdataset.yml --command train 57 | ``` 58 | 59 | The output will be in the `output_dir` directory specified in the config file. The config file is in YAML format. The format is defined by [utils/tsmixer_conf.py](utils/tsmixer_conf.py). 60 | 61 | Plot the loss curves: 62 | 63 | ```bash 64 | python main.py --conf conf.etdataset.yml --command loss --show 65 | ``` 66 | 67 | Predict some of the validation data and plot it: 68 | 69 | ```bash 70 | python main.py --conf conf.etdataset.yml --command predict --show 71 | ``` 72 | 73 | Run a grid search over the hyperparameters: 74 | 75 | ```bash 76 | python main.py --conf conf.etdataset.gridsearch.yml --command grid-search 77 | ``` 78 | 79 | Note that the format of the config file is different for the grid search. The format is defined by [utils/tsmixer_grid_search_conf.py](utils/tsmixer_grid_search_conf.py). 80 | 81 | ### Tests 82 | 83 | Run the tests with `pytest`: 84 | 85 | ```bash 86 | cd tests 87 | pytest 88 | ``` 89 | 90 | ## Implementation notes from the paper 91 | 92 | ### Training parameters 93 | 94 | > For multivariate long-term forecasting datasets, we follow the settings in recent research (Liu et al., 2022b; Zhou et al., 2022a; Nie et al., 2023). We set the input length L = 512 as suggested in Nie et al. (2023) and evaluate the results for prediction lengths of T = {96, 192, 336, 720}. 
We use the Adam optimization algorithm (Kingma & Ba, 2015) to minimize the mean square error (MSE) training objective, and consider MSE and mean absolute error (MAE) as the evaluation metrics. We apply reversible instance normalization (Kim et al., 2022) to ensure a fair comparison with the state-of-the-art PatchTST (Nie et al., 2023).
95 | 
96 | > For the M5 dataset, we mostly follow the data processing from Alexandrov et al. (2020). We consider the prediction length of T = 28 (same as the competition), and set the input length to L = 35. We optimize log-likelihood of negative binomial distribution as suggested by Salinas et al. (2020). We follow the competition’s protocol (Makridakis et al., 2022) to aggregate the predictions at different levels and evaluate them using the weighted root mean squared scaled error (WRMSSE). More details about the experimental setup and hyperparameter tuning can be found in Appendices C and E.
97 | 
98 | ### Reversible Instance Normalization for Time Series Forecasting
99 | 
100 | Reversible instance normalization: https://openreview.net/pdf?id=cGDAkQo1C0p
101 | 
102 | > First, we normalize the input data x(i) using its instance-specific mean and standard deviation, which is widely accepted as instance normalization (Ulyanov et al., 2016). The mean and standard deviation are computed for every instance x(i) ∈ RTx of the input data (Fig. 2(a-3)) as
103 | 
104 | ```
105 | Mean[xi_kt] = mean_{j=1}^Tx ( xi_kj )
106 | Var[xi_kt] = var_{j=1}^Tx ( xi_kj )
107 | ```
108 | where `i` = sample in the batch, `K` = num variables (features), `Tx` = num time steps in input, `Ty` = num time steps in output (prediction).
109 | 
110 | > Then, we apply the normalization to the **input data** (sent to model) as
111 | 
112 | ```
113 | xhati_kt = gamma_k * (xi_kt - Mean[xi_kt]) / sqrt(Var[xi_kt] + epsilon) + beta_k
114 | ```
115 | 
116 | where gamma_k and beta_k are learnable parameters for each variable k (**recall: K = num features**).
117 | 
118 | After the final layer of the model, we get the output `yhati_kt` and apply the reverse transformation to the **output data** (sent to the loss function) as
119 | 
120 | ```
121 | yi_kt = (yhati_kt - beta_k) * sqrt(Var[xi_kt] + epsilon) / gamma_k + Mean[xi_kt]
122 | ```
123 | 
124 | where `yhati_kt` is the output of the model for variable `k` at time `t` for sample `i`, and `yi_kt` is sent to the loss function.
125 | 
126 | ### Details on multivariate time series forecasting experiments
127 | 
128 | Input = matrix X of size (L,C) where L = num time steps, C = num features
129 | Output = prediction of size (T,C) where T = num time steps
130 | 
131 | > B.3.2 Basic TSMixer for Multivariate Time Series Forecasting
132 | > For long-term time series forecasting (LTSF) tasks, TSMixer only uses the historical target time series X as input. A series of mixer blocks are applied to project the input data to a latent representation of size C. The final output is then projected to the prediction length T:
133 | ```
134 | O_1 = Mix[C->C] (X)
135 | O_k = Mix[C->C] (O_{k-1}), for k = 2,...,K
136 | Y = TP[L->T] (O_K)
137 | ```
138 | > where O_k is the latent representation of the k-th mixer block and Ŷ is the prediction. We project the sequence to length T after the mixer blocks as T may be quite long in LTSF tasks.
139 | 
140 | i.e. keep the number of features fixed at C and the time length fixed at L throughout the mixer blocks, then project to the longer length T for the output.
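
To make the block structure concrete, here is a minimal sketch of the three equations above, assembled from this repo's own layers in [utils/model.py](utils/model.py). The `MiniTSMixer` class below is illustrative only, not part of the package; it is essentially what `TSMixerModelExclRIN` does, and `TSMixerModel` additionally wraps the same structure in reversible instance normalization:

```python
import torch.nn as nn

from utils.model import TSMixingLayer, TSTemporalProjection


class MiniTSMixer(nn.Module):
    """Sketch only: K shape-preserving mixer blocks Mix[C->C], then TP[L->T]."""

    def __init__(self, L: int, T: int, C: int, K: int, hidden: int, dropout: float = 0.3):
        super().__init__()
        # Each mixer block maps (batch, L, C) -> (batch, L, C)
        self.blocks = nn.ModuleList(
            TSMixingLayer(input_length=L, no_feats=C, feat_mixing_hidden_channels=hidden, dropout=dropout)
            for _ in range(K)
        )
        # Temporal projection maps the time axis from L to T
        self.temp_proj = TSTemporalProjection(input_length=L, forecast_length=T)

    def forward(self, x):              # x: (batch, L, C)
        for block in self.blocks:      # O_k = Mix[C->C](O_{k-1}); shape stays (batch, L, C)
            x = block(x)
        return self.temp_proj(x)       # Y = TP[L->T](O_K) -> (batch, T, C)
```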
141 | 
142 | ### Hidden layers of feature mixing
143 | 
144 | > To increase the model capacity, we modify the hidden layers in Feature Mixing by using W2 ∈ R^(H×C), W3 ∈ R^(C×H), b2 ∈ R^H, b3 ∈ R^C in Eq. B.3.1, where H is a hyper-parameter indicating the hidden size.
145 | 
146 | i.e. in the feature mixing block, where there are two fully connected layers, the first projects the number of channels from C->H and the second from H->C, where H is an additional hyperparameter.
147 | 
148 | > Another modification is using pre-normalization (Xiong et al., 2020) instead of post-normalization in residual blocks to keep the input scale.
149 | 
150 | i.e. apply normalization to the input of each residual block, instead of to its output.
151 | 
152 | ### Standardization of data
153 | 
154 | > Specifically, we standardize each covariate independently and do not re-scale the data when evaluating the performance.
155 | 
156 | > Global normalization: Global normalization standardizes all variates of time series independently as a data pre-processing. The standardized data is then used for training and evaluation. It is a common setup in long-term time series forecasting experiments to prevent from the affects of different variate scales. For M5, since there is only one target time series (sales), we do not apply the global normalization.
157 | 
158 | Standardize each feature independently based on the training split, then apply the same mean and standard deviation to the test set.
159 | 
160 | > We train each model with a maximum 100 epochs and do early stopping if the validation loss is not improved after 5 epochs.
161 | 
162 | Max 100 epochs, early stopping after 5 epochs without improvement.
--------------------------------------------------------------------------------
/conf.etdataset.gridsearch.yml:
--------------------------------------------------------------------------------
 1 | ---
 2 | # Any number of parameter ranges to run the grid search over
 3 | param_ranges:
 4 | - input_lengths: [512]
 5 |   prediction_lengths: [96]
 6 |   learning_rates: [0.00001]
 7 |   no_mixer_layers: [2,4]
 8 |   dropouts: [0.3,0.5]
 9 |   feat_mixing_hidden_channels: [64,256]
10 | 
11 | # Output directory - each run will be saved in a subdirectory of this
12 | output_dir: output.etdataset.gridsearch
13 | 
14 | # Num features in the dataset
15 | no_features: 7
16 | 
17 | # Data source
18 | data_src: csv-file
19 | data_src_csv: data/ETTh1.csv
--------------------------------------------------------------------------------
/conf.etdataset.yml:
--------------------------------------------------------------------------------
 1 | ---
 2 | # REQUIRED: Output directory
 3 | output_dir: output.etdataset
 4 | 
 5 | # REQUIRED: Input length
 6 | input_length: 512
 7 | 
 8 | # REQUIRED: Prediction length
 9 | prediction_length: 96
10 | 
11 | # REQUIRED: Number of features
12 | no_features: 7
13 | 
14 | # REQUIRED: Number of mixer layers
15 | no_mixer_layers: 2
16 | 
17 | # REQUIRED: Data source
18 | data_src: csv-file
19 | 
20 | # REQUIRED: Path to the data source
21 | data_src_csv: data/ETTh1.csv
22 | 
23 | # Optional: How to initialize the model (restarts training from scratch if set to from-scratch)
24 | initialize: from-scratch # from-best-checkpoint
25 | 
26 | # Optional: Batch size
27 | batch_size: 32
28 | 
29 | # Optional: Number of epochs
30 | num_epochs: 100
31 | 
32 | # Optional: Learning rate
33 | learning_rate: 0.00001
34 | 
35 | # Optional: Optimizer - Adam, SGD, RMSprop, etc
36 | optimizer: Adam
37 | 
38 | # Optional: Random seed
39 | random_seed: 42
40 | 
41 | # Optional:
Validation split method
42 | validation_split: temporal-holdout
43 | 
44 | # Optional: Validation split holdout - fraction of the data to be used for validation
45 | # (only used if validation_split is set to temporal-holdout)
46 | validation_split_holdout: 0.2
47 | 
48 | # Optional: dropout
49 | dropout: 0.3
50 | 
51 | # Optional: Number of hidden channels in the feature mixing layer
52 | feat_mixing_hidden_channels: null
--------------------------------------------------------------------------------
/download_etdataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import requests
 3 | from loguru import logger
 4 | import argparse
 5 | 
 6 | DATASET_TO_URL = {
 7 |     "ETTh1": "https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTh1.csv",
 8 |     "ETTh2": "https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTh2.csv",
 9 |     "ETTm1": "https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTm1.csv",
10 |     "ETTm2": "https://github.com/zhouhaoyi/ETDataset/raw/11ab373cf9c9f5be7698e219a5a170e1b1c8a930/ETT-small/ETTm2.csv"
11 | }
12 | 
13 | def download_etdataset(url: str, download_directory: str):
14 | 
15 |     # Make sure the directory exists, create it if not
16 |     if download_directory != "":
17 |         os.makedirs(download_directory, exist_ok=True)
18 | 
19 |     # Extract the file name from the URL
20 |     file_name = os.path.join(download_directory, os.path.basename(url))
21 | 
22 |     if os.path.exists(file_name):
23 |         logger.info(f"File {file_name} already exists, skipping download")
24 |         return
25 |     logger.info(f"Downloading {url} to {file_name}")
26 | 
27 |     # Send an HTTP GET request to the URL
28 |     response = requests.get(url)
29 | 
30 |     # Check if the request was successful (status code 200)
31 |     if response.status_code == 200:
32 |         # Save the content of the response to the file
33 |         with open(file_name, "wb") as file:
34 |             file.write(response.content)
35 |         logger.info(f"File downloaded and saved to {file_name}")
36 |     else:
37 |         logger.info(f"Failed to download the file.
Status code: {response.status_code}") 38 | 39 | if __name__ == "__main__": 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--dir", type=str, required=False, default="data", help="Directory to save the downloaded files to") 43 | args = parser.parse_args() 44 | 45 | for dataset,url in DATASET_TO_URL.items(): 46 | download_etdataset(url, download_directory=args.dir) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils import TSMixer, plot_preds, plot_loss, TSMixerConf, TSMixerGridSearch 2 | 3 | import argparse 4 | import yaml 5 | import os 6 | 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--command", type=str, required=True, choices=["train", "predict", "loss", "grid-search"], help="Command to run") 12 | parser.add_argument("--conf", type=str, required=False, help="Path to the configuration file") 13 | parser.add_argument("--no-feats-plot", type=int, required=False, default=6, help="Number of features to plot") 14 | parser.add_argument("--show", action="store_true", required=False, help="Show plots") 15 | args = parser.parse_args() 16 | 17 | if args.command == "train": 18 | # Load configuration 19 | assert args.conf is not None, "Must provide a configuration file" 20 | with open(args.conf, "r") as f: 21 | conf = TSMixerConf.from_dict(yaml.safe_load(f)) 22 | tsmixer = TSMixer(conf) 23 | 24 | # Train 25 | tsmixer.train() 26 | 27 | elif args.command == "predict": 28 | 29 | assert args.conf is not None, "Must provide a configuration file" 30 | with open(args.conf, "r") as f: 31 | conf = TSMixerConf.from_dict(yaml.safe_load(f)) 32 | 33 | # Load best checkpoint 34 | conf.initialize = TSMixerConf.Initialize.FROM_BEST_CHECKPOINT 35 | 36 | tsmixer = TSMixer(conf) 37 | 38 | # Predict on validation dataset 39 | data = tsmixer.predict_val_dataset(max_samples=10, save_inputs=False) 40 | 41 | # Plot predictions 42 | data_plt = data[0] 43 | assert args.no_feats_plot is not None, "Must provide number of features to plot" 44 | plot_preds( 45 | preds=data_plt.pred, 46 | preds_gt=data_plt.pred_gt, 47 | no_feats_plot=args.no_feats_plot, 48 | show=args.show, 49 | fname_save=os.path.join(conf.image_dir, "preds.png") 50 | ) 51 | 52 | elif args.command == "loss": 53 | 54 | assert args.conf is not None, "Must provide a configuration file" 55 | with open(args.conf, "r") as f: 56 | conf = TSMixerConf.from_dict(yaml.safe_load(f)) 57 | 58 | train_data = conf.load_training_metadata_or_new() 59 | plot_loss( 60 | train_data=train_data, 61 | show=args.show, 62 | fname_save=os.path.join(conf.image_dir, "loss.png") 63 | ) 64 | 65 | elif args.command == "grid-search": 66 | 67 | # Load configuration 68 | assert args.conf is not None, "Must provide a configuration file" 69 | with open(args.conf, "r") as f: 70 | conf_grid_search = TSMixerGridSearch.from_dict(yaml.safe_load(f)) 71 | 72 | # Run grid search 73 | for conf in conf_grid_search.iterate(): 74 | tsmixer = TSMixer(conf) 75 | tsmixer.train() 76 | 77 | else: 78 | raise NotImplementedError(f"Command {args.command} not implemented") 79 | 80 | -------------------------------------------------------------------------------- /readme_figures/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smrfeld/tsmixer-pytorch/342a6ebb323efff75f96909203c64c9e1e7d7aa5/readme_figures/loss.png 
-------------------------------------------------------------------------------- /readme_figures/preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smrfeld/tsmixer-pytorch/342a6ebb323efff75f96909203c64c9e1e7d7aa5/readme_figures/preds.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru==0.7.2 2 | mashumaro==3.10 3 | pandas==2.1.3 4 | plotly==5.18.0 5 | pytest==7.4.3 6 | PyYAML==6.0.1 7 | Requests==2.31.0 8 | torch==2.1.1 9 | tqdm==4.66.1 10 | -------------------------------------------------------------------------------- /tests/test_csv.csv: -------------------------------------------------------------------------------- 1 | date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT 2 | 2016-07-01 00:00:00,5.827000141143799,2.009000062942505,1.5989999771118164,0.4620000123977661,4.203000068664552,1.3400000333786009,30.5310001373291 3 | 2016-07-01 01:00:00,5.692999839782715,2.075999975204468,1.4919999837875366,0.4259999990463257,4.142000198364259,1.371000051498413,27.78700065612793 4 | 2016-07-01 02:00:00,5.1570000648498535,1.741000056266785,1.2790000438690186,0.35499998927116394,3.776999950408936,1.218000054359436,27.78700065612793 5 | 2016-07-01 03:00:00,5.0900001525878915,1.9420000314712524,1.2790000438690186,0.3910000026226044,3.806999921798706,1.2790000438690186,25.04400062561035 6 | 2016-07-01 04:00:00,5.357999801635742,1.9420000314712524,1.4919999837875366,0.4620000123977661,3.868000030517578,1.2790000438690186,21.947999954223643 7 | 2016-07-01 05:00:00,5.625999927520752,2.1429998874664307,1.5279999971389768,0.5329999923706055,4.051000118255615,1.371000051498413,21.173999786376953 8 | 2016-07-01 06:00:00,7.166999816894531,2.9470000267028813,2.131999969482422,0.7820000052452087,5.026000022888184,1.8580000400543213,22.79199981689453 9 | 2016-07-01 07:00:00,7.434999942779541,3.282000064849853,2.309999942779541,1.031000018119812,5.086999893188477,2.2239999771118164,23.143999099731445 10 | 2016-07-01 08:00:00,5.559000015258789,3.013999938964844,2.45199990272522,1.1729999780654907,2.9549999237060547,1.4320000410079956,21.66699981689453 11 | 2016-07-01 09:00:00,4.554999828338623,2.5450000762939453,1.919000029563904,0.8169999718666077,2.6800000667572017,1.371000051498413,17.445999145507812 12 | 2016-07-01 10:00:00,4.956999778747559,2.5450000762939453,1.9900000095367432,0.8529999852180481,2.9549999237060547,1.4919999837875366,19.979000091552734 13 | 2016-07-01 11:00:00,5.760000228881836,2.5450000762939453,2.203000068664551,0.8529999852180481,3.441999912261963,1.4919999837875366,20.11899948120117 14 | 2016-07-01 12:00:00,4.689000129699707,2.5450000762939453,1.812000036239624,0.8529999852180481,2.8329999446868896,1.5230000019073486,19.20499992370605 15 | 2016-07-01 13:00:00,4.689000129699707,2.678999900817871,1.7769999504089355,1.24399995803833,3.1070001125335693,1.6139999628067017,18.57200050354004 16 | 2016-07-01 14:00:00,5.0900001525878915,2.9470000267028813,2.45199990272522,1.350000023841858,2.559000015258789,1.4320000410079956,19.55599975585937 17 | 2016-07-01 15:00:00,5.0900001525878915,3.1480000019073486,2.486999988555908,1.350000023841858,2.58899998664856,1.5230000019073486,17.305000305175778 18 | 2016-07-01 16:00:00,4.2199997901916495,2.4110000133514404,1.7059999704360962,0.7820000052452087,2.61899995803833,1.4919999837875366,19.48600006103516 19 | 2016-07-01 
17:00:00,4.75600004196167,2.344000101089477,1.6349999904632568,0.7110000252723694,3.075999975204468,1.4919999837875366,19.13400077819824 20 | 2016-07-01 18:00:00,5.625999927520752,2.880000114440918,2.5230000019073486,1.2079999446868896,3.075999975204468,1.4919999837875366,20.68199920654297 21 | 2016-07-01 19:00:00,5.492000102996826,3.013999938964844,2.45199990272522,1.2079999446868896,3.015000104904175,1.5529999732971191,18.71199989318848 22 | 2016-07-01 20:00:00,5.357999801635742,3.013999938964844,2.45199990272522,1.2079999446868896,2.86299991607666,1.5230000019073486,17.868000030517578 23 | 2016-07-01 21:00:00,5.0900001525878915,2.9470000267028813,2.38100004196167,1.2079999446868896,2.6800000667572017,1.5230000019073486,18.009000778198242 24 | 2016-07-01 22:00:00,4.822999954223633,2.9470000267028813,2.203000068664551,1.1729999780654907,2.61899995803833,1.5230000019073486,18.009000778198242 25 | 2016-07-01 23:00:00,4.622000217437744,2.880000114440918,2.131999969482422,1.1369999647140503,2.4670000076293945,1.4919999837875366,19.76799964904785 26 | 2016-07-02 00:00:00,5.223999977111816,3.0810000896453857,2.700999975204468,1.315000057220459,2.437000036239624,1.5230000019073486,21.104000091552734 27 | 2016-07-02 01:00:00,5.1570000648498535,3.013999938964844,2.878000020980835,1.350000023841858,2.345000028610229,1.4320000410079956,19.69700050354004 28 | 2016-07-02 02:00:00,5.1570000648498535,3.1480000019073486,2.878000020980835,1.4919999837875366,2.2839999198913574,1.4320000410079956,20.048999786376953 29 | 2016-07-02 03:00:00,5.1570000648498535,3.0810000896453857,2.9140000343322754,1.4919999837875366,2.193000078201294,1.4010000228881836,20.75200080871582 30 | 2016-07-02 04:00:00,4.554999828338623,3.0810000896453857,2.45199990272522,1.4919999837875366,2.193000078201294,1.4010000228881836,21.38500022888184 31 | 2016-07-02 05:00:00,5.425000190734863,3.282000064849853,3.0920000076293945,1.7059999704360962,2.437000036239624,1.462000012397766,22.229999542236328 32 | 2016-07-02 06:00:00,5.492000102996826,3.282000064849853,2.5230000019073486,1.4919999837875366,2.984999895095825,1.462000012397766,20.26000022888184 33 | 2016-07-02 07:00:00,5.625999927520752,3.2149999141693115,2.486999988555908,1.4919999837875366,3.075999975204468,1.5230000019073486,21.104000091552734 34 | 2016-07-02 08:00:00,5.559000015258789,3.282000064849853,2.594000101089477,1.6699999570846558,2.924000024795532,1.5230000019073486,20.61199951171875 35 | 2016-07-02 09:00:00,5.223999977111816,3.2149999141693115,2.559000015258789,1.5640000104904177,2.6800000667572017,1.462000012397766,18.36100006103516 36 | 2016-07-02 10:00:00,9.913000106811523,4.956999778747559,6.644999980926514,3.3050000667572017,3.0460000038146973,1.5529999732971191,20.96299934387207 37 | 2016-07-02 11:00:00,11.788000106811525,5.425000190734863,8.173000335693361,2.5230000019073486,3.686000108718872,1.6749999523162842,19.416000366210934 38 | 2016-07-02 12:00:00,9.645000457763672,4.956999778747559,6.751999855041504,2.131999969482422,3.1070001125335693,1.8279999494552608,20.82299995422364 39 | 2016-07-02 13:00:00,10.381999969482422,5.760000228881836,7.4619998931884775,2.559000015258789,2.984999895095825,1.7669999599456787,20.190000534057614 40 | 2016-07-02 14:00:00,8.77400016784668,4.689000129699707,6.111999988555907,2.025000095367432,2.8940000534057617,1.919000029563904,21.315000534057614 41 | 2016-07-02 15:00:00,10.449000358581543,5.1570000648498535,6.965000152587893,2.45199990272522,2.7720000743865967,1.7359999418258667,22.018999099731445 42 | 2016-07-02 
16:00:00,9.845999717712402,4.822999954223633,7.035999774932861,2.664999961853028,2.8940000534057617,1.7669999599456787,20.68199920654297 43 | 2016-07-02 17:00:00,9.913000106811523,4.822999954223633,6.894000053405763,2.4159998893737797,3.2290000915527344,1.7359999418258667,25.465999603271484 44 | 2016-07-02 18:00:00,10.649999618530273,4.689000129699707,6.928999900817871,2.45199990272522,3.38100004196167,1.7970000505447388,25.88800048828125 45 | 2016-07-02 19:00:00,10.11400032043457,4.3540000915527335,6.644999980926514,1.812000036239624,3.1070001125335693,1.7359999418258667,27.85700035095215 46 | 2016-07-02 20:00:00,9.979999542236328,4.152999877929688,6.573999881744385,1.9539999961853027,3.4110000133514404,1.7669999599456787,27.295000076293945 47 | 2016-07-02 21:00:00,9.3100004196167,4.2199997901916495,6.005000114440918,2.131999969482422,3.2290000915527344,1.8580000400543213,22.229999542236328 48 | 2016-07-02 22:00:00,9.444000244140623,4.622000217437744,6.965000152587893,2.1679999828338623,2.9549999237060547,1.8580000400543213,21.947999954223643 49 | 2016-07-02 23:00:00,9.444000244140623,4.287000179290772,6.822999954223633,2.559000015258789,2.58899998664856,1.7359999418258667,27.295000076293945 50 | 2016-07-03 00:00:00,10.381999969482422,5.425000190734863,7.604000091552732,2.309999942779541,2.9549999237060547,1.6749999523162842,29.334999084472656 51 | 2016-07-03 01:00:00,9.779000282287598,5.223999977111816,6.716000080108643,2.8429999351501465,2.650000095367432,1.6749999523162842,26.02799987792969 52 | 2016-07-03 02:00:00,10.381999969482422,4.689000129699707,7.320000171661378,2.203000068664551,2.984999895095825,1.8580000400543213,24.34000015258789 53 | 2016-07-03 03:00:00,9.779000282287598,4.152999877929688,6.822999954223633,1.9900000095367432,2.5280001163482666,1.6749999523162842,26.45000076293945 54 | 2016-07-03 04:00:00,10.717000007629396,4.75600004196167,7.355999946594237,2.806999921798706,2.650000095367432,1.7970000505447388,25.95800018310547 55 | 2016-07-03 05:00:00,10.3149995803833,4.689000129699707,7.390999794006348,2.45199990272522,2.924000024795532,1.8580000400543213,24.05900001525879 56 | 2016-07-03 06:00:00,12.592000007629395,5.223999977111816,8.670999526977539,2.203000068664551,3.7160000801086426,1.9490000009536743,25.32500076293945 57 | 2016-07-03 07:00:00,11.119000434875488,4.622000217437744,7.888999938964844,2.8429999351501465,3.625,1.919000029563904,23.636999130249023 58 | 2016-07-03 08:00:00,10.649999618530273,4.421000003814697,7.035999774932861,2.025000095367432,3.594000101089477,1.919000029563904,26.3799991607666 59 | 2016-07-03 09:00:00,10.04699993133545,4.2199997901916495,6.432000160217285,1.6699999570846558,3.686000108718872,1.9490000009536743,27.364999771118164 60 | 2016-07-03 10:00:00,11.720999717712402,5.0900001525878915,7.888999938964844,2.559000015258789,3.563999891281128,1.8580000400543213,28.06800079345703 61 | 2016-07-03 11:00:00,12.123000144958494,5.357999801635742,8.065999984741211,2.486999988555908,4.081999778747559,1.919000029563904,29.47500038146973 62 | 2016-07-03 12:00:00,9.979999542236328,5.0229997634887695,6.857999801635742,2.559000015258789,3.2899999618530287,1.8580000400543213,26.80200004577637 63 | 2016-07-03 13:00:00,9.243000030517578,4.956999778747559,6.289999961853027,2.630000114440918,3.13700008392334,1.888000011444092,29.968000411987305 64 | 2016-07-03 14:00:00,10.180999755859377,5.425000190734863,7.177999973297119,3.0199999809265137,3.075999975204468,1.888000011444092,30.38999938964844 65 | 2016-07-03 
15:00:00,9.645000457763672,5.425000190734863,7.10699987411499,2.664999961853028,3.015000104904175,1.8279999494552608,31.16399955749512 66 | 2016-07-03 16:00:00,9.779000282287598,4.889999866485597,6.502999782562256,2.984999895095825,3.075999975204468,2.009999990463257,29.756999969482425 67 | 2016-07-03 17:00:00,11.119000434875488,5.1570000648498535,7.320000171661378,2.9140000343322754,3.806999921798706,1.9800000190734863,32.28900146484375 68 | 2016-07-03 18:00:00,11.052000045776367,4.956999778747559,7.390999794006348,2.5230000019073486,3.686000108718872,1.9800000190734863,31.9379997253418 69 | 2016-07-03 19:00:00,10.784000396728516,4.889999866485597,7.214000225067139,2.486999988555908,3.594000101089477,1.888000011444092,28.56100082397461 70 | 2016-07-03 20:00:00,11.185999870300293,4.889999866485597,7.177999973297119,2.345000028610229,3.9600000381469727,1.919000029563904,21.525999069213867 71 | 2016-07-03 21:00:00,10.449000358581543,4.889999866485597,6.610000133514402,2.309999942779541,3.806999921798706,2.0409998893737797,22.229999542236328 72 | 2016-07-03 22:00:00,9.57800006866455,5.760000228881836,6.787000179290772,3.127000093460083,3.259000062942505,1.888000011444092,19.416000366210934 73 | 2016-07-03 23:00:00,9.3100004196167,5.760000228881836,6.610000133514402,3.0559999942779537,3.1679999828338623,1.888000011444092,18.57200050354004 74 | 2016-07-04 00:00:00,9.913000106811523,5.894000053405763,6.254000186920166,2.630000114440918,3.015000104904175,1.8580000400543213,21.66699981689453 75 | 2016-07-04 01:00:00,8.975000381469727,4.956999778747559,6.289999961853027,2.664999961853028,2.86299991607666,1.8279999494552608,25.535999298095703 76 | 2016-07-04 02:00:00,8.640000343322754,4.822999954223633,6.1479997634887695,2.594000101089477,2.924000024795532,1.8279999494552608,27.85700035095215 77 | 2016-07-04 03:00:00,9.175999641418457,5.492000102996826,5.578999996185303,2.38100004196167,2.86299991607666,1.8580000400543213,27.92799949645996 78 | 2016-07-04 04:00:00,9.109000205993652,4.822999954223633,5.6500000953674325,2.5230000019073486,2.7720000743865967,1.7970000505447388,24.62100028991699 79 | 2016-07-04 05:00:00,9.845999717712402,5.559000015258789,5.96999979019165,2.9489998817443848,3.1070001125335693,1.888000011444092,23.847999572753906 80 | 2016-07-04 06:00:00,11.588000297546387,5.425000190734863,7.390999794006348,2.806999921798706,3.806999921798706,1.9800000190734863,23.07399940490723 81 | 2016-07-04 07:00:00,11.788000106811525,6.09499979019165,7.214000225067139,2.984999895095825,3.8989999294281006,2.0409998893737797,22.51099967956543 82 | 2016-07-04 08:00:00,10.583000183105467,5.960999965667725,7.14300012588501,2.9140000343322754,3.6549999713897705,2.0710000991821294,21.66699981689453 83 | 2016-07-04 09:00:00,11.588000297546387,6.296000003814697,7.568999767303468,3.0559999942779537,3.4719998836517334,2.009999990463257,25.395000457763672 84 | 2016-07-04 10:00:00,11.92199993133545,6.2290000915527335,7.710999965667725,3.0559999942779537,3.746000051498413,1.9490000009536743,25.18400001525879 85 | 2016-07-04 11:00:00,12.324000358581545,5.559000015258789,8.421999931335451,3.233999967575073,4.203000068664552,1.9800000190734863,29.54599952697754 86 | 2016-07-04 12:00:00,10.381999969482422,5.894000053405763,6.857999801635742,2.630000114440918,3.563999891281128,1.9490000009536743,29.47500038146973 87 | 2016-07-04 13:00:00,10.04699993133545,5.425000190734863,6.751999855041504,3.0199999809265137,3.3199999332427983,1.9490000009536743,29.263999938964844 88 | 2016-07-04 
14:00:00,10.515999794006348,6.027999877929688,7.10699987411499,3.375999927520752,3.13700008392334,1.919000029563904,30.952999114990234 89 | 2016-07-04 15:00:00,10.717000007629396,6.09499979019165,6.787000179290772,3.0199999809265137,3.1679999828338623,2.009999990463257,31.72599983215332 90 | 2016-07-04 16:00:00,9.979999542236328,5.0229997634887695,6.502999782562256,2.559000015258789,3.441999912261963,2.0409998893737797,33.132999420166016 91 | 2016-07-04 17:00:00,11.31999969482422,5.0900001525878915,7.355999946594237,2.45199990272522,3.868000030517578,2.0409998893737797,28.982999801635746 92 | 2016-07-04 18:00:00,11.38700008392334,4.956999778747559,7.355999946594237,2.45199990272522,4.295000076293944,2.193000078201294,28.982999801635746 93 | 2016-07-04 19:00:00,9.376999855041504,3.884999990463257,6.894000053405763,2.2390000820159908,2.4670000076293945,1.187999963760376,31.72599983215332 94 | 2016-07-04 20:00:00,10.11400032043457,4.0859999656677255,7.14300012588501,2.2390000820159908,2.9549999237060547,1.462000012397766,25.18400001525879 95 | 2016-07-04 21:00:00,10.381999969482422,4.822999954223633,6.894000053405763,2.309999942779541,3.503000020980835,2.009999990463257,30.5310001373291 96 | 2016-07-04 22:00:00,9.645000457763672,4.889999866485597,6.610000133514402,1.919000029563904,3.259000062942505,1.919000029563904,27.645999908447266 97 | 2016-07-04 23:00:00,12.72599983215332,6.497000217437744,9.345999717712402,3.4820001125335693,3.1679999828338623,1.9800000190734863,25.465999603271484 98 | 2016-07-05 00:00:00,11.98900032043457,5.625999927520752,8.777000427246094,2.9489998817443848,3.1979999542236333,1.9800000190734863,25.95800018310547 99 | 2016-07-05 01:00:00,12.524999618530273,6.296000003814697,8.954999923706055,3.1630001068115234,3.13700008392334,2.009999990463257,25.95800018310547 100 | 2016-07-05 02:00:00,12.324000358581545,6.296000003814697,8.812999725341797,3.375999927520752,2.984999895095825,1.919000029563904,26.02799987792969 -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | from utils import TSMixerModelExclRIN, TSBatchNorm2d, TSFeatMixingResBlock, TSMixingLayer, TSTemporalProjection, TSTimeMixingResBlock, TSMixerModel 5 | 6 | import pytest 7 | import torch 8 | 9 | class TestModel: 10 | 11 | def _time_series(self, batch_size: int, input_length: int, no_feats: int) -> torch.Tensor: 12 | return torch.randn(batch_size, input_length, no_feats) 13 | 14 | def test_batchnorm2d(self): 15 | bn = TSBatchNorm2d() 16 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 17 | data_out = bn(data) 18 | assert data_out.shape == data.shape 19 | 20 | 21 | def test_tstemporalprojection(self): 22 | tp = TSTemporalProjection(input_length=100, forecast_length=30) 23 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 24 | data_out = tp(data) 25 | assert data_out.shape == (32, 30, 5) 26 | 27 | 28 | def test_tsmixinglayer(self): 29 | ml = TSMixingLayer(input_length=100, no_feats=5, dropout=0.5, feat_mixing_hidden_channels=10) 30 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 31 | data_out = ml(data) 32 | assert data_out.shape == data.shape 33 | 34 | 35 | def test_tsfeatmixingresblock(self): 36 | fmrb = TSFeatMixingResBlock(width_feats=5, dropout=0.5, width_feats_hidden=10) 37 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 38 | data_out = 
fmrb(data) 39 | assert data_out.shape == data.shape 40 | 41 | 42 | def test_tstimemixingresblock(self): 43 | tmrb = TSTimeMixingResBlock(width_time=100, dropout=0.5) 44 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 45 | data_out = tmrb(data) 46 | assert data_out.shape == data.shape 47 | 48 | 49 | def test_tsmixer_excl_rin(self): 50 | 51 | ts = TSMixerModelExclRIN( 52 | input_length=100, 53 | forecast_length=10, 54 | no_feats=5, 55 | no_mixer_layers=3, 56 | dropout=0.5, 57 | feat_mixing_hidden_channels=5 58 | ) 59 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 60 | forecast = ts(data) 61 | 62 | assert forecast.shape == (32, 10, 5) 63 | 64 | 65 | def test_tsmixer(self): 66 | 67 | ts = TSMixerModel( 68 | input_length=100, 69 | forecast_length=10, 70 | no_feats=5, 71 | no_mixer_layers=3, 72 | dropout=0.5, 73 | feat_mixing_hidden_channels=5 74 | ) 75 | data = self._time_series(batch_size=32, input_length=100, no_feats=5) 76 | forecast = ts(data) 77 | 78 | assert forecast.shape == (32, 10, 5) 79 | -------------------------------------------------------------------------------- /tests/test_tsmixer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | from utils import TSMixer, TSMixerConf 5 | 6 | import pytest 7 | import torch 8 | import os 9 | import shutil 10 | 11 | TEST_CSV_NO_FEATS = 7 12 | 13 | 14 | @pytest.fixture 15 | def conf(): 16 | output_dir = "TMP_output_dir" 17 | if os.path.exists(output_dir): 18 | shutil.rmtree(output_dir) 19 | 20 | yield TSMixerConf( 21 | input_length=20, 22 | prediction_length=5, 23 | no_features=TEST_CSV_NO_FEATS, 24 | no_mixer_layers=2, 25 | output_dir=output_dir, 26 | data_src=TSMixerConf.DataSrc.CSV_FILE, 27 | data_src_csv="test_csv.csv", 28 | batch_size=4, 29 | num_epochs=10, 30 | learning_rate=0.001, 31 | optimizer="Adam" 32 | ) 33 | 34 | if os.path.exists(output_dir): 35 | shutil.rmtree(output_dir) 36 | 37 | 38 | @pytest.fixture 39 | def tsmixer(conf: TSMixerConf): 40 | return TSMixer(conf=conf) 41 | 42 | 43 | class TestTsMixer: 44 | 45 | def test_load_data(self, tsmixer: TSMixer): 46 | loader_train, loader_val, _ = tsmixer.conf.create_data_loaders_train_val() 47 | for loader in [loader_train, loader_val]: 48 | batch_input, batch_pred = next(iter(loader)) 49 | assert batch_input.shape == (tsmixer.conf.batch_size, tsmixer.conf.input_length, tsmixer.conf.no_features) 50 | assert batch_pred.shape == (tsmixer.conf.batch_size, tsmixer.conf.prediction_length, tsmixer.conf.no_features) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import TSBatchNorm2d, TSFeatMixingResBlock, TSMixerModelExclRIN, TSMixingLayer, TSTemporalProjection, TSTimeMixingResBlock, TSMixerModel 2 | from .plotting import plot_preds, plot_loss 3 | from .tsmixer_conf import TSMixerConf, TrainingMetadata 4 | from .tsmixer_grid_search_conf import TSMixerGridSearch 5 | from .tsmixer import TSMixer -------------------------------------------------------------------------------- /utils/load_csv.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch.utils.data import DataLoader, Dataset, Subset 3 | from enum import Enum 4 | import torch 5 | from typing import Tuple, Callable, Optional, List 6 | from dataclasses import dataclass 7 | from mashumaro import 
DataClassDictMixin
 8 | from loguru import logger
 9 | 
10 | 
11 | class ValidationSplit(Enum):
12 | 
13 |     TEMPORAL_HOLDOUT = "temporal-holdout"
14 |     "Reserve the last portion (e.g., 10-20%) of your time-ordered data for validation, and use the remaining data for training. This is a simple and widely used approach."
15 | 
16 | 
17 | class DataframeDataset(Dataset):
18 |     """Dataset from a pandas dataframe
19 |     """
20 | 
21 |     def __init__(self, df: pd.DataFrame, window_size_input: int, window_size_predict: int, transform: Optional[Callable] = None):
22 |         """Constructor
23 | 
24 |         Args:
25 |             df (pd.DataFrame): Dataframe
26 |             window_size_input (int): Input window size
27 |             window_size_predict (int): Prediction window size
28 |             transform (Optional[Callable], optional): Transforms such as normalization applied to time series. Defaults to None.
29 |         """
30 |         window_size_total = window_size_input + window_size_predict
31 |         assert len(df) > window_size_total, f"Dataset length ({len(df)}) must be greater than window size ({window_size_total})"
32 |         self.df = df
33 |         self.window_size_input = window_size_input
34 |         self.window_size_predict = window_size_predict
35 |         self.transform = transform
36 | 
37 |     def __len__(self):
38 |         return len(self.df) - self.window_size_input - self.window_size_predict
39 | 
40 |     def get_sample(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
41 |         """Get a window sample. Input from [idx, idx + window_size_input), prediction from [idx + window_size_input, idx + window_size_input + window_size_predict)
42 | 
43 |         Args:
44 |             idx (int): Index
45 | 
46 |         Returns:
47 |             Tuple[torch.Tensor, torch.Tensor]: Input and prediction tensors
48 |         """
49 | 
50 |         # Check if the index plus window size exceeds the length of the dataset
51 |         if idx + self.window_size_input + self.window_size_predict > len(self.df):
52 |             raise IndexError(f"Index ({idx}) + window_size_input ({self.window_size_input}) + window_size_predict ({self.window_size_predict}) exceeds dataset length ({len(self.df)})")
53 | 
54 |         # Window the data
55 |         sample_input = self.df.iloc[idx:idx + self.window_size_input, :]
56 |         sample_pred = self.df.iloc[idx + self.window_size_input:idx + self.window_size_input + self.window_size_predict, :]
57 | 
58 |         # Convert to torch tensor
59 |         sample_input = torch.tensor(sample_input.values, dtype=torch.float32)
60 |         sample_pred = torch.tensor(sample_pred.values, dtype=torch.float32)
61 | 
62 |         # Apply transform
63 |         if self.transform is not None:
64 |             sample_input = self.transform(sample_input)
65 |             sample_pred = self.transform(sample_pred)
66 | 
67 |         return sample_input, sample_pred
68 | 
69 |     def __getitem__(self, idx):
70 |         if torch.is_tensor(idx):
71 |             idx = idx.tolist()
72 | 
73 |         if isinstance(idx, list):
74 |             # Handle a list of indices
75 |             samples = [self.get_sample(i) for i in idx]
76 |             return samples
77 |         else:
78 |             # Handle a single index
79 |             return self.get_sample(idx)
80 | 
81 | 
82 | @dataclass
83 | class DataNormalization(DataClassDictMixin):
84 |     mean_each_feature: Optional[List[float]] = None
85 |     "Mean for each feature"
86 | 
87 |     std_each_feature: Optional[List[float]] = None
88 |     "Std for each feature"
89 | 
90 | 
91 | def load_csv_dataset(
92 |     csv_file: str,
93 |     batch_size: int,
94 |     input_length: int,
95 |     prediction_length: int,
96 |     val_split: ValidationSplit,
97 |     val_split_holdout: float = 0.2,
98 |     shuffle: bool = True,
99 |     normalize_each_feature: bool = True,
100 |     data_norm_exist: Optional[DataNormalization] = None
101 | ) -> Tuple[DataLoader, DataLoader, DataNormalization]:
102 | """Load a CSV dataset 103 | 104 | Args: 105 | csv_file (str): CSV file path 106 | batch_size (int): Batch size 107 | input_length (int): Input length 108 | prediction_length (int): Prediction length 109 | val_split (ValidationSplit): Validation split method 110 | val_split_holdout (float, optional): Holdout fraction for validation (last X% of data) - only used for TEMPORAL_HOLDOUT. Defaults to 0.2. 111 | shuffle (bool, optional): True to shuffle data. Defaults to True. 112 | normalize_each_feature (bool, optional): Normalize each feature. Defaults to True. 113 | data_norm_exist (Optional[DataNormalization], optional): Existing normalization data - apply this instead of recalculating. Defaults to None. 114 | 115 | Returns: 116 | Tuple[DataLoader, DataLoader, DataNormalization]: Training and validation data loaders, and normalization data 117 | """ 118 | 119 | # Load the CSV file into a DataFrame 120 | df_raw = pd.read_csv(csv_file) 121 | df = df_raw.set_index('date') 122 | 123 | # Make dataset 124 | dataset = DataframeDataset(df, window_size_input=input_length, window_size_predict=prediction_length) 125 | no_pts = len(dataset) 126 | 127 | # Split the data into training and validation 128 | if val_split == ValidationSplit.TEMPORAL_HOLDOUT: 129 | idx_train_val = int(no_pts * (1-val_split_holdout)) 130 | else: 131 | raise NotImplementedError(f"Validation split {val_split} not implemented") 132 | 133 | # Normalize each feature separately 134 | if data_norm_exist is None: 135 | data_norm_exist = DataNormalization() 136 | 137 | # Compute mean and std on training data from pandas dataframe 138 | filtered_df = df[:idx_train_val] 139 | data_norm_exist.mean_each_feature = list(filtered_df.mean().values) 140 | data_norm_exist.std_each_feature = list(filtered_df.std().values) 141 | logger.debug(f"Computed data mean for each feature: {data_norm_exist.mean_each_feature}") 142 | logger.debug(f"Computed data std for each feature: {data_norm_exist.std_each_feature}") 143 | 144 | if normalize_each_feature: 145 | assert data_norm_exist.mean_each_feature is not None, "Must provide data mean for each feature" 146 | assert data_norm_exist.std_each_feature is not None, "Must provide data std for each feature" 147 | 148 | # Create a normalization function 149 | transform = lambda x: (x - torch.Tensor(data_norm_exist.mean_each_feature)) / torch.Tensor(data_norm_exist.std_each_feature) 150 | 151 | # Apply the normalization function 152 | dataset.transform = transform 153 | 154 | # Splits 155 | train_dataset = Subset(dataset, range(idx_train_val)) 156 | val_dataset = Subset(dataset, range(idx_train_val, no_pts)) 157 | 158 | loader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) 159 | loader_val = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle) 160 | 161 | return loader_train, loader_val, data_norm_exist 162 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class TSBatchNorm2d(nn.Module): 6 | 7 | def __init__(self): 8 | super(TSBatchNorm2d, self).__init__() 9 | self.bn = nn.BatchNorm2d(num_features=1) 10 | 11 | def forward(self, x: torch.Tensor) -> torch.Tensor: 12 | # Input x: (batch_size, time, features) 13 | 14 | # Reshape input_data to (batch_size, 1, timepoints, features) 15 | x = x.unsqueeze(1) 16 | 17 | # Forward pass 18 | output = self.bn(x) 19 | 20 | # Reshape the output back to 
(batch_size, timepoints, features) 21 | output = output.squeeze(1) 22 | return output 23 | 24 | 25 | class TSTimeMixingResBlock(nn.Module): 26 | 27 | def __init__(self, width_time: int, dropout: float): 28 | super(TSTimeMixingResBlock, self).__init__() 29 | self.norm = TSBatchNorm2d() 30 | 31 | self.lin = nn.Linear(in_features=width_time, out_features=width_time) 32 | self.dropout = nn.Dropout(p=dropout) 33 | self.act = nn.ReLU() 34 | 35 | def forward(self, x: torch.Tensor) -> torch.Tensor: 36 | # Input x: (batch_size, time, features) 37 | y = self.norm(x) 38 | 39 | # Now rotate such that shape is (batch_size, features, time) 40 | y = torch.transpose(y, 1, 2) 41 | 42 | # Apply MLP to time dimension 43 | y = self.lin(y) 44 | y = self.act(y) 45 | 46 | # Rotate back such that shape is (batch_size, time, features) 47 | y = torch.transpose(y, 1, 2) 48 | 49 | # Dropout 50 | y = self.dropout(y) 51 | 52 | # Add residual connection 53 | return x + y 54 | 55 | 56 | class TSFeatMixingResBlock(nn.Module): 57 | 58 | def __init__(self, width_feats: int, width_feats_hidden: int, dropout: float): 59 | super(TSFeatMixingResBlock, self).__init__() 60 | self.norm = TSBatchNorm2d() 61 | 62 | self.lin_1 = nn.Linear(in_features=width_feats, out_features=width_feats_hidden) 63 | self.lin_2 = nn.Linear(in_features=width_feats_hidden, out_features=width_feats) 64 | self.dropout_1 = nn.Dropout(p=dropout) 65 | self.dropout_2 = nn.Dropout(p=dropout) 66 | self.act = nn.ReLU() 67 | 68 | 69 | def forward(self, x: torch.Tensor) -> torch.Tensor: 70 | # Input x: (batch_size, time, features) 71 | y = self.norm(x) 72 | 73 | # Apply MLP to feat dimension 74 | y = self.lin_1(y) 75 | y = self.act(y) 76 | y = self.dropout_1(y) 77 | y = self.lin_2(y) 78 | y = self.dropout_2(y) 79 | 80 | # Add residual connection 81 | return x + y 82 | 83 | 84 | class TSMixingLayer(nn.Module): 85 | 86 | def __init__(self, input_length: int, no_feats: int, feat_mixing_hidden_channels: int, dropout: float): 87 | super(TSMixingLayer, self).__init__() 88 | self.time_mixing = TSTimeMixingResBlock(width_time=input_length, dropout=dropout) 89 | self.feat_mixing = TSFeatMixingResBlock(width_feats=no_feats, width_feats_hidden=feat_mixing_hidden_channels, dropout=dropout) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | # Input x: (batch_size, time, features) 93 | y = self.time_mixing(x) 94 | y = self.feat_mixing(y) 95 | return y 96 | 97 | 98 | class TSTemporalProjection(nn.Module): 99 | 100 | def __init__(self, input_length: int, forecast_length: int): 101 | super(TSTemporalProjection, self).__init__() 102 | self.lin = nn.Linear(in_features=input_length, out_features=forecast_length) 103 | 104 | def forward(self, x: torch.Tensor) -> torch.Tensor: 105 | # Input x: (batch_size, time, features) 106 | # Now rotate such that shape is (batch_size, features, time=input_length) 107 | y = torch.transpose(x, 1, 2) 108 | 109 | # Apply linear projection -> shape is (batch_size, features, time=forecast_length) 110 | y = self.lin(y) 111 | 112 | # Rotate back such that shape is (batch_size, time=forecast_length, features) 113 | y = torch.transpose(y, 1, 2) 114 | return y 115 | 116 | 117 | class TSMixerModelExclRIN(nn.Module): 118 | 119 | def __init__(self, input_length: int, forecast_length: int, no_feats: int, feat_mixing_hidden_channels: int, no_mixer_layers: int, dropout: float): 120 | super(TSMixerModelExclRIN, self).__init__() 121 | self.temp_proj = TSTemporalProjection(input_length=input_length, forecast_length=forecast_length) 122 | mixer_layers 
= [] 123 | for _ in range(no_mixer_layers): 124 | mixer_layers.append(TSMixingLayer(input_length=input_length, no_feats=no_feats, feat_mixing_hidden_channels=feat_mixing_hidden_channels, dropout=dropout)) 125 | self.mixer_layers = nn.ModuleList(mixer_layers) 126 | 127 | def forward(self, x: torch.Tensor) -> torch.Tensor: 128 | # Input x: (batch_size, time, features) 129 | for mixer_layer in self.mixer_layers: 130 | x = mixer_layer(x) 131 | 132 | # Apply temporal projection -> shape is (batch_size, time=forecast_length, features) 133 | x = self.temp_proj(x) 134 | 135 | return x 136 | 137 | 138 | class TSMixerModel(nn.Module): 139 | """Include Reversible instance normalization https://openreview.net/pdf?id=cGDAkQo1C0p 140 | """ 141 | 142 | def __init__(self, input_length: int, forecast_length: int, no_feats: int, feat_mixing_hidden_channels: int, no_mixer_layers: int, dropout: float, eps: float = 1e-8): 143 | super(TSMixerModel, self).__init__() 144 | self.eps = eps 145 | 146 | # Scale and shift params to learn 147 | self.scale = nn.Parameter(torch.ones(no_feats)) 148 | self.shift = nn.Parameter(torch.zeros(no_feats)) 149 | 150 | # ts mixer layers 151 | self.ts = TSMixerModelExclRIN( 152 | input_length=input_length, 153 | forecast_length=forecast_length, 154 | no_feats=no_feats, 155 | feat_mixing_hidden_channels=feat_mixing_hidden_channels, 156 | no_mixer_layers=no_mixer_layers, 157 | dropout=dropout 158 | ) 159 | 160 | def forward(self, x: torch.Tensor) -> torch.Tensor: 161 | # Input x: (batch_size, time, features) 162 | 163 | # Compute mean, var across time dimension 164 | # mean: (batch_size, 1, features) 165 | # var: (batch_size, 1, features) 166 | mean = torch.mean(x, dim=1, keepdim=True) 167 | var = torch.var(x, dim=1, keepdim=True) 168 | 169 | # Normalize across time dimension 170 | # x: (batch_size, time, features) 171 | x = (x - mean) / torch.sqrt(var + self.eps) 172 | 173 | # Apply scale and shift in each feature dimension separately 174 | # x: (batch_size, time, features) 175 | # scale: (features) 176 | # shift: (features) 177 | x = x * self.scale + self.shift 178 | 179 | # Apply ts mixer layers 180 | x = self.ts(x) 181 | 182 | # Apply inverse scale and shift in each feature dimension separately 183 | # x: (batch_size, time, features) 184 | # scale: (features) 185 | # shift: (features) 186 | x = (x - self.shift) / self.scale 187 | 188 | # Unnormalize across time dimension 189 | # x: (batch_size, time, features) 190 | # mean: (batch_size, 1, features) 191 | # var: (batch_size, 1, features) 192 | x = x * torch.sqrt(var + self.eps) + mean 193 | 194 | return x -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | from .tsmixer_conf import TrainingMetadata 2 | 3 | from typing import List, Tuple, Optional 4 | from loguru import logger 5 | 6 | 7 | def plot_preds(preds: List[List[float]], preds_gt: List[List[float]], no_feats_plot: int, fname_save: Optional[str] = None, inputs: Optional[List[List[float]]] = None, show: bool = True): 8 | """Plot predictions 9 | 10 | Args: 11 | preds (List[List[float]]): Predictions of shape (no_samples, no_feats) 12 | preds_gt (List[List[float]]): Predictions of shape (no_samples, no_feats) 13 | no_feats_plot (int): Number of features to plot 14 | fname_save (Optional[str], optional): File name to save the plot. Defaults to None. 
15 | inputs (Optional[List[List[float]]], optional): Input of shape (no_samples, no_feats) 16 | show (bool): Show the plot 17 | """ 18 | import plotly.graph_objects as go 19 | from plotly.subplots import make_subplots 20 | 21 | no_feats = len(preds[0]) 22 | if no_feats_plot > no_feats: 23 | logger.warning(f"no_feats_plot ({no_feats_plot}) is larger than no_feats ({no_feats}). Setting no_feats_plot to no_feats") 24 | no_feats_plot = no_feats 25 | 26 | no_cols = 3 27 | no_rows = int(no_feats_plot / no_cols) 28 | if no_feats_plot % no_cols != 0: 29 | no_rows += 1 30 | 31 | fig = make_subplots(rows=no_rows, cols=no_cols, subplot_titles=[f"Feature {ifeat}" for ifeat in range(no_feats_plot)]) 32 | 33 | no_inputs = len(inputs) if inputs is not None else 0 34 | x_preds = list(range(no_inputs, no_inputs + len(preds))) 35 | for ifeat in range(no_feats_plot): 36 | row = int(ifeat / no_cols) + 1 37 | col = (ifeat % no_cols) + 1 38 | 39 | if inputs is not None: 40 | x_inputs = list(range(len(inputs))) 41 | fig.add_trace(go.Scatter(x=x_inputs, y=[in_y[ifeat] for in_y in inputs], mode="lines", name=f"Inputs", line=dict(color="black"), showlegend=ifeat==0), row=row, col=col) 42 | 43 | fig.add_trace(go.Scatter(x=x_preds, y=[pred[ifeat] for pred in preds_gt], mode="lines", name=f"Ground truth", line=dict(color="red"), showlegend=ifeat==0), row=row, col=col) 44 | fig.add_trace(go.Scatter(x=x_preds, y=[pred[ifeat] for pred in preds], mode="lines", name=f"Model", line=dict(color="blue"), showlegend=ifeat==0), row=row, col=col) 45 | 46 | fig.update_layout( 47 | height=300*no_rows, 48 | width=400*no_cols, 49 | title_text="Predictions", 50 | font=dict(size=18), 51 | xaxis_title_text="Time", 52 | yaxis_title_text="Signal", 53 | ) 54 | 55 | if fname_save is not None: 56 | fig.write_image(fname_save) 57 | logger.info(f"Saved plot to {fname_save}") 58 | 59 | if show: 60 | fig.show() 61 | 62 | return fig 63 | 64 | 65 | def plot_loss(train_data: TrainingMetadata, fname_save: Optional[str] = None, show: bool = True): 66 | """Plot loss 67 | 68 | Args: 69 | train_data (TSMixer.TrainingMetadata): Training metadata 70 | fname_save (Optional[str], optional): File name to save the plot. Defaults to None. 71 | show (bool): Show the plot 72 | """ 73 | import plotly.graph_objects as go 74 | 75 | fig = go.Figure() 76 | x = [ epoch for epoch in train_data.epoch_to_data.keys() ] 77 | y = [ data.val_loss for data in train_data.epoch_to_data.values() ] 78 | fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="Val. 
loss")) 79 | y = [ data.train_loss for data in train_data.epoch_to_data.values() ] 80 | fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="Train loss")) 81 | 82 | fig.update_layout( 83 | height=500, 84 | width=700, 85 | title_text="Loss during training", 86 | xaxis_title_text="Epoch", 87 | yaxis_title_text="Loss", 88 | font=dict(size=18), 89 | ) 90 | 91 | if fname_save is not None: 92 | fig.write_image(fname_save) 93 | logger.info(f"Saved plot to {fname_save}") 94 | 95 | if show: 96 | fig.show() 97 | 98 | return fig -------------------------------------------------------------------------------- /utils/tsmixer.py: -------------------------------------------------------------------------------- 1 | from .tsmixer_conf import TSMixerConf, TrainingMetadata, makedirs 2 | from .model import TSMixerModel 3 | from .load_csv import DataNormalization 4 | 5 | import os 6 | from typing import Optional, Tuple, Dict, List 7 | import torch 8 | from loguru import logger 9 | from tqdm import tqdm 10 | import json 11 | import time 12 | import shutil 13 | from dataclasses import dataclass 14 | from mashumaro import DataClassDictMixin 15 | import yaml 16 | 17 | 18 | class TSMixer: 19 | """TSMixer including training and prediction methods 20 | """ 21 | 22 | 23 | def __init__(self, conf: TSMixerConf): 24 | """Constructor for TSMixer class 25 | 26 | Args: 27 | conf (TSMixerConf): Configuration 28 | """ 29 | conf.check_valid() 30 | self.conf = conf 31 | 32 | # Create the model 33 | self.model = TSMixerModel( 34 | input_length=self.conf.input_length, 35 | forecast_length=self.conf.prediction_length, 36 | no_feats=self.conf.no_features, 37 | feat_mixing_hidden_channels=self.conf.feat_mixing_hidden_channels or self.conf.no_features, 38 | no_mixer_layers=self.conf.no_mixer_layers, 39 | dropout=self.conf.dropout 40 | ) 41 | 42 | # Move to device 43 | self.model.to(self.conf.device) 44 | 45 | # Load the model 46 | if self.conf.initialize == self.conf.Initialize.FROM_LATEST_CHECKPOINT: 47 | self.load_checkpoint(fname=self.conf.checkpoint_latest) 48 | elif self.conf.initialize == self.conf.Initialize.FROM_BEST_CHECKPOINT: 49 | self.load_checkpoint(fname=self.conf.checkpoint_best) 50 | elif self.conf.initialize == self.conf.Initialize.FROM_SCRATCH: 51 | pass 52 | else: 53 | raise NotImplementedError(f"Initialize {self.conf.initialize} not implemented") 54 | 55 | 56 | def load_checkpoint(self, fname: str, optimizer: Optional[torch.optim.Optimizer] = None) -> Tuple[int,float]: 57 | """Load a checkpoint, optionally including the optimizer state 58 | 59 | Args: 60 | fname (str): File name 61 | optimizer (Optional[torch.optim.Optimizer], optional): Optimizer to update from checkpoint. Defaults to None. 
/utils/tsmixer.py:
--------------------------------------------------------------------------------
from .tsmixer_conf import TSMixerConf, TrainingMetadata, makedirs
from .model import TSMixerModel
from .load_csv import DataNormalization

import os
from typing import Optional, Tuple, Dict, List
import torch
from loguru import logger
from tqdm import tqdm
import json
import time
import shutil
from dataclasses import dataclass
from mashumaro import DataClassDictMixin
import yaml


class TSMixer:
    """TSMixer including training and prediction methods
    """


    def __init__(self, conf: TSMixerConf):
        """Constructor for TSMixer class

        Args:
            conf (TSMixerConf): Configuration
        """
        conf.check_valid()
        self.conf = conf

        # Create the model
        self.model = TSMixerModel(
            input_length=self.conf.input_length,
            forecast_length=self.conf.prediction_length,
            no_feats=self.conf.no_features,
            feat_mixing_hidden_channels=self.conf.feat_mixing_hidden_channels or self.conf.no_features,
            no_mixer_layers=self.conf.no_mixer_layers,
            dropout=self.conf.dropout
            )

        # Move to device
        self.model.to(self.conf.device)

        # Load the model weights if starting from a checkpoint
        if self.conf.initialize == self.conf.Initialize.FROM_LATEST_CHECKPOINT:
            self.load_checkpoint(fname=self.conf.checkpoint_latest)
        elif self.conf.initialize == self.conf.Initialize.FROM_BEST_CHECKPOINT:
            self.load_checkpoint(fname=self.conf.checkpoint_best)
        elif self.conf.initialize == self.conf.Initialize.FROM_SCRATCH:
            pass
        else:
            raise NotImplementedError(f"Initialize {self.conf.initialize} not implemented")


    def load_checkpoint(self, fname: str, optimizer: Optional[torch.optim.Optimizer] = None) -> Tuple[int, float]:
        """Load a checkpoint, optionally including the optimizer state

        Args:
            fname (str): File name
            optimizer (Optional[torch.optim.Optimizer], optional): Optimizer to update from checkpoint. Defaults to None.

        Returns:
            Tuple[int,float]: Epoch and loss
        """
        logger.debug(f"Loading model weights from {fname}")
        # Map tensors to the configured device so checkpoints load across cpu/cuda/mps
        checkpoint = torch.load(fname, map_location=self.conf.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer is not None:
            logger.debug(f"Loading optimizer state from {fname}")
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # The epoch and loss are stored in the checkpoint whether or not an optimizer
        # was passed, so always read and return them (previously they were only
        # available on the optimizer branch, violating the declared return type)
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        logger.info(f"Loaded checkpoint from epoch {epoch} with loss {loss}")
        return epoch, loss


    def predict(self, batch_input: torch.Tensor) -> torch.Tensor:
        """Predict the output for a batch of input data

        Args:
            batch_input (torch.Tensor): Input data of shape (batch_size, input_length (time), no_features)

        Returns:
            torch.Tensor: Predicted output of shape (batch_size, prediction_length (time), no_features)
        """
        self.model.eval()

        # Check size
        assert batch_input.shape[1] == self.conf.input_length, f"Input length {batch_input.shape[1]} does not match configuration {self.conf.input_length}"
        assert batch_input.shape[2] == self.conf.no_features, f"Number of features {batch_input.shape[2]} does not match configuration {self.conf.no_features}"

        # Predict
        batch_input = batch_input.to(self.conf.device)
        with torch.no_grad():
            batch_pred_hat = self.model(batch_input)
        return batch_pred_hat
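As a quick smoke test, `predict` can be driven with a random batch shaped to the configuration. This is a sketch, not repository code: the config values mirror the README example, `output_smoke` is a hypothetical directory, and the output shape follows the documented contract above.

```python
# Hypothetical smoke test for TSMixer.predict.
import torch
from utils.tsmixer import TSMixer
from utils.tsmixer_conf import TSMixerConf

conf = TSMixerConf(
    output_dir="output_smoke",
    input_length=512,
    no_features=7,
    no_mixer_layers=4,
    prediction_length=96,
    data_src=TSMixerConf.DataSrc.CSV_FILE,
    data_src_csv="data/ETTh1.csv",
    device="cpu",  # override the mps default if not on Apple silicon
)
mixer = TSMixer(conf)

batch = torch.randn(2, conf.input_length, conf.no_features)
preds = mixer.predict(batch)
print(preds.shape)  # expected: torch.Size([2, 96, 7])
```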
    def load_data_norm(self) -> Optional[DataNormalization]:
        """Load the data normalization from a JSON file

        Returns:
            Optional[DataNormalization]: Data normalization, or None if the file does not exist
        """
        if os.path.exists(self.conf.data_norm_json):
            logger.debug(f"Loading data normalization from {self.conf.data_norm_json}")
            with open(self.conf.data_norm_json, "r") as f:
                return DataNormalization.from_dict(json.load(f))
        else:
            return None


    @dataclass
    class PredData(DataClassDictMixin):
        """Prediction data
        """

        pred_gt: List[List[float]]
        "Ground truth prediction"

        pred: List[List[float]]
        "Model prediction"

        inputs: Optional[List[List[float]]] = None
        "Inputs"


    def predict_val_dataset(self, max_samples: Optional[int] = None, save_inputs: bool = False) -> List[PredData]:
        """Predict on the validation dataset

        Args:
            max_samples (Optional[int], optional): Maximum number of samples to predict from the validation dataset. Defaults to None.
            save_inputs (bool, optional): Save the inputs as well as the predictions. Defaults to False.

        Returns:
            List[PredData]: List of predictions
        """

        # Change the batch size to 1 and do not shuffle the data for consistency
        batch_size_save = self.conf.batch_size
        shuffle_save = self.conf.shuffle
        self.conf.batch_size = 1
        self.conf.shuffle = False

        # Load the data normalization if it exists and use it
        data_norm = self.load_data_norm()

        # Create the loaders
        _, loader_val, _ = self.conf.create_data_loaders_train_val(data_norm)

        # Predict. Create the iterator once, outside the loop - calling
        # next(iter(loader_val)) inside the loop would restart the loader
        # and yield the same first sample on every iteration. Also cap
        # max_samples at the loader length to avoid StopIteration.
        no_samples = min(max_samples, len(loader_val)) if max_samples is not None else len(loader_val)
        it_val = iter(loader_val)
        data_list: List[TSMixer.PredData] = []
        for _ in tqdm(range(no_samples), desc="Predicting"):
            batch_input, batch_pred = next(it_val)
            batch_pred_hat = self.predict(batch_input)
            data = TSMixer.PredData(
                pred_gt=batch_pred.tolist()[0],
                pred=batch_pred_hat.tolist()[0],
                inputs=batch_input.tolist()[0] if save_inputs else None
                )
            data_list.append(data)

        # Save data to json
        with open(self.conf.pred_val_dataset_json, "w") as f:
            json.dump([d.to_dict() for d in data_list], f)
        logger.info(f"Saved data to {f.name}")

        # Reset options
        self.conf.batch_size = batch_size_save
        self.conf.shuffle = shuffle_save

        return data_list
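A sketch of gluing `predict_val_dataset` to the plotting helper, reusing the `mixer` from the earlier sketch and assuming the ETDataset CSV has been downloaded; the `plot_preds` keyword names are inferred from its docstring above, so treat them as assumptions.

```python
# Hypothetical glue between predict_val_dataset and plot_preds.
from utils.plotting import plot_preds

data_list = mixer.predict_val_dataset(max_samples=1, save_inputs=True)
d = data_list[0]  # a TSMixer.PredData instance
plot_preds(preds=d.pred, preds_gt=d.pred_gt, inputs=d.inputs, show=True)
```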
    def train(self):
        """Train the model
        """

        # Create the optimizer
        optimizer_cls = getattr(torch.optim, self.conf.optimizer)
        optimizer = optimizer_cls(self.model.parameters(), lr=self.conf.learning_rate)

        # Load the checkpoint and data normalization if resuming
        if self.conf.initialize == self.conf.Initialize.FROM_LATEST_CHECKPOINT:
            epoch_start, val_loss_best = self.load_checkpoint(fname=self.conf.checkpoint_latest, optimizer=optimizer)
            data_norm = self.load_data_norm()
        elif self.conf.initialize == self.conf.Initialize.FROM_BEST_CHECKPOINT:
            epoch_start, val_loss_best = self.load_checkpoint(fname=self.conf.checkpoint_best, optimizer=optimizer)
            data_norm = self.load_data_norm()
        elif self.conf.initialize == self.conf.Initialize.FROM_SCRATCH:
            epoch_start, val_loss_best = 0, float("inf")

            # Clear the output directory
            if os.path.exists(self.conf.output_dir):
                logger.warning(f"Output directory {self.conf.output_dir} already exists. Deleting it to start over. You have 8 seconds.")
                for _ in range(8):
                    print(".", end="", flush=True)
                    time.sleep(1)
                print("")
                shutil.rmtree(self.conf.output_dir)
            makedirs(self.conf.output_dir)

            # Save initial weights
            self._save_checkpoint(epoch=epoch_start, optimizer=optimizer, loss=val_loss_best, fname=self.conf.checkpoint_init)

            # The normalization will be computed from the data when starting from scratch
            data_norm = None

            # Copy the config to the output directory for reference
            fname_conf = os.path.join(self.conf.output_dir, "conf.yml")
            with open(fname_conf, "w") as f:
                yaml.dump(self.conf.to_dict(), f, indent=3)
                logger.info(f"Saved configuration to {f.name}")

        else:
            raise NotImplementedError(f"Initialize {self.conf.initialize} not implemented")

        train_data = self.conf.load_training_metadata_or_new(epoch_start)

        # Create the loaders
        loader_train, loader_val, data_norm = self.conf.create_data_loaders_train_val(data_norm)

        # Write data normalization
        self.conf.write_data_norm(data_norm)

        # Train
        epoch_last_improvement = None
        for epoch in range(epoch_start, self.conf.num_epochs):
            logger.info(f"Epoch {epoch+1}/{self.conf.num_epochs}")
            t0 = time.time()

            # Training
            train_loss = 0
            for batch_input, batch_pred in tqdm(loader_train, desc="Training batches"):
                batch_input, batch_pred = batch_input.to(self.conf.device), batch_pred.to(self.conf.device)
                train_loss += self._train_step(batch_input, batch_pred, optimizer)

            # Validation loss
            self.model.eval()
            with torch.no_grad():
                val_loss = 0
                for batch_input, batch_pred in tqdm(loader_val, desc="Validation batches"):
                    batch_input, batch_pred = batch_input.to(self.conf.device), batch_pred.to(self.conf.device)
                    val_loss += self._compute_loss(batch_input, batch_pred).item()

            # Log
            train_loss /= len(loader_train)
            val_loss /= len(loader_val)
            dur = time.time() - t0
            logger.info(f"Training loss: {train_loss:.5f} val: {val_loss:.5f} duration: {dur:.2f}s")

            # Store metadata about training
            train_data.epoch_to_data[epoch] = TrainingMetadata.EpochData(epoch=epoch, train_loss=train_loss, val_loss=val_loss, duration_seconds=dur)

            # Save checkpoint
            if val_loss < val_loss_best:
                logger.info(f"New best validation loss: {val_loss:.5f}")
                self._save_checkpoint(epoch=epoch, optimizer=optimizer, loss=val_loss, fname=self.conf.checkpoint_best)
                val_loss_best = val_loss
                epoch_last_improvement = epoch
            self._save_checkpoint(epoch=epoch, optimizer=optimizer, loss=val_loss, fname=self.conf.checkpoint_latest)
            self.conf.write_training_metadata(train_data)

            # Early stopping
            if epoch_last_improvement is not None and self.conf.early_stopping_patience is not None and epoch - epoch_last_improvement >= self.conf.early_stopping_patience:
                logger.info(f"Stopping early after {epoch - epoch_last_improvement} epochs without improvement in validation loss.")
                break


    def _save_checkpoint(self, epoch: int, optimizer: torch.optim.Optimizer, loss: float, fname: str):
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, fname)
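The checkpoint format round-trips through the four keys used by `_save_checkpoint` and `load_checkpoint` above. A sketch, assuming a prior training run with the hypothetical `output_smoke` config:

```python
# Inspect a saved checkpoint and resume its state into an existing TSMixer `mixer`.
import torch

ckpt = torch.load("output_smoke/best.pth", map_location="cpu")
print(sorted(ckpt.keys()))  # ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict']

opt = torch.optim.Adam(mixer.model.parameters(), lr=1e-5)
epoch, loss = mixer.load_checkpoint(fname="output_smoke/best.pth", optimizer=opt)
```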
"""Compute the loss 282 | 283 | Args: 284 | batch_input (torch.Tensor): Batch input of shape (batch_size, input_length (time), no_features) 285 | batch_pred (torch.Tensor): Batch prediction of shape (batch_size, prediction_length (time), no_features) 286 | 287 | Returns: 288 | torch.Tensor: Loss (MSE) 289 | """ 290 | 291 | # Forward pass 292 | batch_pred_hat = self.model(batch_input) 293 | 294 | # Compute MSE loss 295 | loss = torch.nn.functional.mse_loss(batch_pred_hat, batch_pred) 296 | 297 | # Normalize the loss by the batch size 298 | # batch_size = batch_input.size(0) 299 | # loss /= batch_size 300 | 301 | return loss 302 | 303 | 304 | def _train_step(self, batch_input: torch.Tensor, batch_pred: torch.Tensor, optimizer: torch.optim.Optimizer) -> float: 305 | """Training step 306 | 307 | Args: 308 | batch_input (torch.Tensor): Input data of shape (batch_size, input_length (time), no_features) 309 | batch_pred (torch.Tensor): Prediction data of shape (batch_size, prediction_length (time), no_features) 310 | optimizer (torch.optim.Optimizer): Optimizer 311 | 312 | Returns: 313 | float: Loss (MSE) 314 | """ 315 | optimizer.zero_grad() 316 | 317 | # Train mode 318 | self.model.train() 319 | 320 | # Loss 321 | loss = self._compute_loss(batch_input, batch_pred) 322 | 323 | # Backward pass 324 | loss.backward() 325 | 326 | # Update parameters 327 | optimizer.step() 328 | 329 | return loss.item() -------------------------------------------------------------------------------- /utils/tsmixer_conf.py: -------------------------------------------------------------------------------- 1 | from .load_csv import DataNormalization 2 | 3 | from dataclasses import dataclass 4 | from mashumaro import DataClassDictMixin 5 | from enum import Enum 6 | import os 7 | from typing import Optional, Tuple, Dict, List 8 | from torch.utils.data import DataLoader 9 | from loguru import logger 10 | import json 11 | 12 | 13 | def makedirs(d: str): 14 | if d != "": 15 | os.makedirs(d, exist_ok=True) 16 | 17 | 18 | @dataclass 19 | class TSMixerConf(DataClassDictMixin): 20 | 21 | class Initialize(Enum): 22 | FROM_LATEST_CHECKPOINT = "from-latest-checkpoint" 23 | "Load the model from the latest checkpoint" 24 | 25 | FROM_BEST_CHECKPOINT = "from-best-checkpoint" 26 | "Load the model from the best checkpoint" 27 | 28 | FROM_SCRATCH = "from-scratch" 29 | "Initialize the model from scratch" 30 | 31 | class DataSrc(Enum): 32 | 33 | CSV_FILE = "csv-file" 34 | "Load the dataset from a CSV file" 35 | 36 | class ValidationSplit(Enum): 37 | 38 | TEMPORAL_HOLDOUT = "temporal-holdout" 39 | "Reserve the last portion (e.g., 10-20%) of your time-ordered data for validation, and use the remaining data for training. This is a simple and widely used approach." 40 | 41 | output_dir: str 42 | "Directory where to save checkpoints and generated images" 43 | 44 | input_length: int 45 | "Number of time steps to use as input" 46 | 47 | no_features: int 48 | "Number of features in the dataset" 49 | 50 | no_mixer_layers: int 51 | "Number of mixer layers" 52 | 53 | prediction_length: int 54 | "Number of time steps to predict" 55 | 56 | data_src: DataSrc 57 | "Where to load the dataset from" 58 | 59 | device: str = "mps" 60 | "Device to use for training" 61 | 62 | data_src_csv: Optional[str] = None 63 | "Path to the CSV file to load the dataset from. 
/utils/tsmixer_conf.py:
--------------------------------------------------------------------------------
from .load_csv import DataNormalization

from dataclasses import dataclass
from mashumaro import DataClassDictMixin
from enum import Enum
import os
from typing import Optional, Tuple, Dict, List
from torch.utils.data import DataLoader
from loguru import logger
import json


def makedirs(d: str):
    if d != "":
        os.makedirs(d, exist_ok=True)


@dataclass
class TSMixerConf(DataClassDictMixin):

    class Initialize(Enum):
        FROM_LATEST_CHECKPOINT = "from-latest-checkpoint"
        "Load the model from the latest checkpoint"

        FROM_BEST_CHECKPOINT = "from-best-checkpoint"
        "Load the model from the best checkpoint"

        FROM_SCRATCH = "from-scratch"
        "Initialize the model from scratch"

    class DataSrc(Enum):

        CSV_FILE = "csv-file"
        "Load the dataset from a CSV file"

    class ValidationSplit(Enum):

        TEMPORAL_HOLDOUT = "temporal-holdout"
        "Reserve the last portion (e.g., 10-20%) of the time-ordered data for validation and use the remaining data for training. This is a simple and widely used approach."

    output_dir: str
    "Directory where to save checkpoints and generated images"

    input_length: int
    "Number of time steps to use as input"

    no_features: int
    "Number of features in the dataset"

    no_mixer_layers: int
    "Number of mixer layers"

    prediction_length: int
    "Number of time steps to predict"

    data_src: DataSrc
    "Where to load the dataset from"

    device: str = "mps"
    "Device to use for training"

    data_src_csv: Optional[str] = None
    "Path to the CSV file to load the dataset from. Only used if data_src is CSV_FILE"

    batch_size: int = 64
    "Batch size"

    shuffle: bool = True
    "Shuffle the data"

    num_epochs: int = 10
    "Number of epochs to train for"

    learning_rate: float = 0.001
    "Learning rate"

    optimizer: str = "Adam"
    "Optimizer to use"

    random_seed: int = 42
    "Random seed for reproducibility"

    validation_split: ValidationSplit = ValidationSplit.TEMPORAL_HOLDOUT
    "How to split the data into training and validation"

    validation_split_holdout: float = 0.2
    "Use the last X% of the data for validation. Only used for TEMPORAL_HOLDOUT"

    initialize: Initialize = Initialize.FROM_SCRATCH
    "How to initialize the model"

    dropout: float = 0.5
    "Dropout"

    feat_mixing_hidden_channels: Optional[int] = None
    "Number of hidden channels in the feature mixing MLP. If None, uses the same as the number of input features."

    early_stopping_patience: Optional[int] = 5
    "Early stopping patience. If the validation loss does not improve over this many epochs, stop early. If None, no early stopping is used."

    @property
    def image_dir(self):
        makedirs(self.output_dir)
        makedirs(os.path.join(self.output_dir, "images"))
        return os.path.join(self.output_dir, "images")

    @property
    def checkpoint_init(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "init.pth")

    @property
    def checkpoint_best(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "best.pth")

    @property
    def checkpoint_latest(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "latest.pth")

    @property
    def train_progress_json(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "loss.json")

    @property
    def pred_val_dataset_json(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "pred_val_dataset.json")

    @property
    def data_norm_json(self):
        makedirs(self.output_dir)
        return os.path.join(self.output_dir, "data_norm.json")

    def check_valid(self):
        assert 0 <= self.validation_split_holdout <= 1, "validation_split_holdout must be between 0 and 1"

        # Check that the device exists and is available
        import torch
        assert self.device in ["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "mps"], f"Device {self.device} not supported"
        if self.device.startswith("cuda"):
            assert torch.cuda.is_available(), "CUDA is not available"
        elif self.device == "mps":
            assert torch.backends.mps.is_available(), "MPS is not available"
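For reference, a config file consistent with this schema might look like the following; the values are illustrative (they mirror the README's example parameters), and omitted fields fall back to the defaults above:

```yaml
# Illustrative config in the TSMixerConf schema (enum fields use their string values)
output_dir: output_etdataset
input_length: 512
prediction_length: 96
no_features: 7
no_mixer_layers: 4
data_src: csv-file
data_src_csv: data/ETTh1.csv
batch_size: 32
num_epochs: 100
learning_rate: 0.00001
optimizer: Adam
validation_split: temporal-holdout
validation_split_holdout: 0.2
initialize: from-scratch
dropout: 0.3
feat_mixing_hidden_channels: 256
early_stopping_patience: 5
```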
    def load_training_metadata_or_new(self, epoch_start: Optional[int] = None) -> "TrainingMetadata":
        """Load the training progress from a JSON file, or create a new one

        Args:
            epoch_start (Optional[int], optional): Starting epoch - data for this and later epochs will be removed if not None. Defaults to None.

        Returns:
            TrainingMetadata: Training metadata
        """
        if os.path.exists(self.train_progress_json):
            with open(self.train_progress_json, "r") as f:
                tp = TrainingMetadata.from_dict(json.load(f))

            # Remove data for epochs at or after epoch_start, since they will be retrained
            if epoch_start is not None:
                tp.epoch_to_data = {epoch: tp.epoch_to_data[epoch] for epoch in tp.epoch_to_data if epoch < epoch_start}

            return tp
        else:
            return TrainingMetadata(epoch_to_data={})


    def write_data_norm(self, data_norm: DataNormalization):
        """Write the data normalization to a JSON file

        Args:
            data_norm (DataNormalization): Data normalization
        """
        with open(self.data_norm_json, "w") as f:
            json.dump(data_norm.to_dict(), f, indent=3)
            logger.debug(f"Saved data normalization to {f.name}")


    def write_training_metadata(self, train_data: "TrainingMetadata"):
        """Write the training progress to a JSON file

        Args:
            train_data (TrainingMetadata): Training metadata to write
        """
        if os.path.dirname(self.train_progress_json) != "":
            makedirs(os.path.dirname(self.train_progress_json))
        with open(self.train_progress_json, "w") as f:
            json.dump(train_data.to_dict(), f, indent=3)


    def create_data_loaders_train_val(self, data_norm: Optional[DataNormalization] = None) -> Tuple[DataLoader, DataLoader, DataNormalization]:
        """Create the training and validation data loaders

        Args:
            data_norm (Optional[DataNormalization], optional): Data normalization to use, otherwise it will be calculated from the data. Defaults to None.

        Returns:
            Tuple[DataLoader, DataLoader, DataNormalization]: Training data loader, validation data loader, and the data normalization used
        """

        if self.data_src == self.DataSrc.CSV_FILE:
            assert self.data_src_csv is not None, "data_src_csv must be set if data_src is CSV_FILE"

            from .load_csv import load_csv_dataset, ValidationSplit
            return load_csv_dataset(
                csv_file=self.data_src_csv,
                batch_size=self.batch_size,
                input_length=self.input_length,
                prediction_length=self.prediction_length,
                val_split=ValidationSplit(self.validation_split.value),
                val_split_holdout=self.validation_split_holdout,
                shuffle=self.shuffle,
                data_norm_exist=data_norm
                )
        else:
            raise NotImplementedError(f"data_src {self.data_src} not implemented")


@dataclass
class TrainingMetadata(DataClassDictMixin):

    @dataclass
    class EpochData(DataClassDictMixin):
        epoch: int
        "Epoch number"

        train_loss: float
        "Training loss"

        val_loss: float
        "Validation loss"

        duration_seconds: float
        "Duration of the epoch in seconds"

    epoch_to_data: Dict[int, EpochData]
    "Mapping from epoch number to epoch data"
--------------------------------------------------------------------------------
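A sketch of inspecting the `loss.json` that `write_training_metadata` produces; the path assumes the illustrative config above. Note that `json.dump` writes the integer epoch keys as strings, which mashumaro should coerce back to `int` per the `Dict[int, EpochData]` annotation.

```python
# Hypothetical post-hoc analysis of a training run's loss.json.
import json
from utils.tsmixer_conf import TrainingMetadata

with open("output_etdataset/loss.json", "r") as f:
    meta = TrainingMetadata.from_dict(json.load(f))

best_epoch = min(meta.epoch_to_data, key=lambda e: meta.epoch_to_data[e].val_loss)
print(best_epoch, meta.epoch_to_data[best_epoch].val_loss)
```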
/utils/tsmixer_grid_search_conf.py:
--------------------------------------------------------------------------------
from .tsmixer_conf import TSMixerConf

from dataclasses import dataclass, field
from mashumaro import DataClassDictMixin
from typing import Optional, Tuple, Dict, List, Iterator
from loguru import logger
import os


@dataclass
class TSMixerGridSearch(DataClassDictMixin):
    """Configuration for grid search
    """

    @dataclass
    class ParamRange(DataClassDictMixin):

        learning_rates: List[float]
        "Learning rates to try"

        no_mixer_layers: List[int]
        "Numbers of mixer layers to try"

        dropouts: List[float]
        "Dropout values to try"

        input_lengths: List[int]
        "Numbers of time steps to use as input"

        prediction_lengths: List[int]
        "Numbers of time steps to predict"

        feat_mixing_hidden_channels: List[Optional[int]] = field(default_factory=lambda: [None])
        "Numbers of hidden channels in the feature mixing MLP. If None, uses the same as the number of input features."

        batch_sizes: List[int] = field(default_factory=lambda: [64])
        "Batch sizes to try"

        num_epochs: List[int] = field(default_factory=lambda: [100])
        "Numbers of epochs to train for"

        optimizers: List[str] = field(default_factory=lambda: ["Adam"])
        "Optimizers to use"

        @property
        def options_str(self) -> str:
            s = []
            s.append(("lr", str(self.learning_rates)))
            s.append(("nmix", str(self.no_mixer_layers)))
            s.append(("drop", str(self.dropouts)))
            s.append(("in", str(self.input_lengths)))
            s.append(("pred", str(self.prediction_lengths)))
            s.append(("hidden", str(self.feat_mixing_hidden_channels)))
            s.append(("batch", str(self.batch_sizes)))
            s.append(("epochs", str(self.num_epochs)))
            s.append(("opt", str(self.optimizers)))

            # Sort by key
            s = sorted(s, key=lambda x: x[0])

            return "_".join([f"{k}{v}" for k, v in s])

    param_ranges: List[ParamRange]
    "Any number of parameter ranges to try"

    output_dir: str
    "Output directory"

    no_features: int
    "Number of features in the dataset"

    data_src: TSMixerConf.DataSrc = TSMixerConf.DataSrc.CSV_FILE
    "Where to load the dataset from"

    data_src_csv: Optional[str] = None
    "Path to the CSV file to load the dataset from. Only used if data_src is CSV_FILE"
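A grid-search config consistent with this schema might look like the following; the values are illustrative, and each entry in `param_ranges` is swept exhaustively by `iterate()` below:

```yaml
# Illustrative grid-search config in the TSMixerGridSearch schema
output_dir: output_gridsearch
no_features: 7
data_src: csv-file
data_src_csv: data/ETTh1.csv
param_ranges:
  - learning_rates: [0.001, 0.0001, 0.00001]
    no_mixer_layers: [2, 4]
    dropouts: [0.3, 0.5]
    input_lengths: [512]
    prediction_lengths: [96]
    batch_sizes: [32]
    num_epochs: [100]
    optimizers: [Adam]
```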
    def iterate(self) -> Iterator[TSMixerConf]:
        """Iterate over all configurations in the grid search

        Yields:
            Iterator[TSMixerConf]: Configuration for a single run
        """
        for idx, param_range in enumerate(self.param_ranges):
            logger.info("===========================================")
            logger.info(f"Grid search iteration {idx+1}/{len(self.param_ranges)}")
            logger.info("===========================================")

            for learning_rate in param_range.learning_rates:
                for no_mixer_layers in param_range.no_mixer_layers:
                    for dropout in param_range.dropouts:
                        for feat_mixing_hidden_channels in param_range.feat_mixing_hidden_channels:
                            for input_length in param_range.input_lengths:
                                for prediction_length in param_range.prediction_lengths:
                                    for batch_size in param_range.batch_sizes:
                                        for num_epochs in param_range.num_epochs:
                                            for optimizer in param_range.optimizers:
                                                # Output subdir. Name it after the chosen values rather than
                                                # param_range.options_str (which encodes the full ranges and is
                                                # therefore identical for every combination), so that each run
                                                # gets its own directory instead of overwriting the previous one.
                                                run_name = f"lr{learning_rate}_nmix{no_mixer_layers}_drop{dropout}_in{input_length}_pred{prediction_length}_hidden{feat_mixing_hidden_channels}_batch{batch_size}_epochs{num_epochs}_opt{optimizer}"
                                                output_dir = os.path.join(self.output_dir, run_name)
                                                conf = TSMixerConf(
                                                    input_length=input_length,
                                                    prediction_length=prediction_length,
                                                    no_features=self.no_features,
                                                    no_mixer_layers=no_mixer_layers,
                                                    output_dir=output_dir,
                                                    data_src=self.data_src,
                                                    data_src_csv=self.data_src_csv,
                                                    batch_size=batch_size,
                                                    num_epochs=num_epochs,
                                                    learning_rate=learning_rate,
                                                    optimizer=optimizer,
                                                    dropout=dropout,
                                                    feat_mixing_hidden_channels=feat_mixing_hidden_channels
                                                    )
                                                logger.info(f"TSMixer config: {conf}")
                                                logger.info(f"Output sub-dir: {output_dir}")
                                                yield conf
--------------------------------------------------------------------------------
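The nine nested loops above are equivalent to a single `itertools.product` over the ranges. A behavior-preserving sketch of that refactor (not part of the repository), for one `param_range`:

```python
# Flattening the nested sweep with itertools.product.
from itertools import product

for lr, nmix, drop, hidden, in_len, pred_len, batch, epochs, opt in product(
    param_range.learning_rates,
    param_range.no_mixer_layers,
    param_range.dropouts,
    param_range.feat_mixing_hidden_channels,
    param_range.input_lengths,
    param_range.prediction_lengths,
    param_range.batch_sizes,
    param_range.num_epochs,
    param_range.optimizers,
):
    ...  # build and yield the TSMixerConf exactly as above
```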