├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── checkpoints ├── .gitignore └── README.md ├── config ├── default │ ├── autoformer.yaml │ ├── csdi.yaml │ ├── dlinear.yaml │ ├── gru.yaml │ ├── gru_maf.yaml │ ├── gru_nvp.yaml │ ├── itransformer.yaml │ ├── linear.yaml │ ├── mean.yaml │ ├── moderntcn.yaml │ ├── naive.yaml │ ├── nhits.yaml │ ├── nlinear.yaml │ ├── patchtst.yaml │ ├── timegrad.yaml │ ├── timesnet.yaml │ ├── trans_maf.yaml │ ├── transformer.yaml │ ├── tsdiff.yaml │ └── tsmixer.yaml ├── ltsf │ ├── electricity_ltsf │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── etth1 │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── etth2 │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── ettm1 │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── ettm2 │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── exchange_ltsf │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── illness_ltsf │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── traffic_ltsf │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ └── weather_ltsf │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml ├── m4 │ ├── m4_daily │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── m4_weekly │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml │ ├── m5 │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml 
│ └── tourism_monthly │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ └── timegrad.yaml ├── multi_hor │ ├── autoformer.yaml │ └── elastst.yaml ├── pipeline_config.yaml ├── stsf │ ├── electricity │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru.yaml │ │ ├── gru_maf.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ ├── timegrad.yaml │ │ ├── timesnet.yaml │ │ ├── trans_maf.yaml │ │ └── transformer.yaml │ ├── exchange │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru.yaml │ │ ├── gru_maf.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ ├── timegrad.yaml │ │ ├── timesnet.yaml │ │ ├── trans_maf.yaml │ │ └── transformer.yaml │ ├── solar │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru.yaml │ │ ├── gru_maf.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ ├── timegrad.yaml │ │ ├── timesnet.yaml │ │ ├── trans_maf.yaml │ │ └── transformer.yaml │ ├── traffic │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru.yaml │ │ ├── gru_maf.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ ├── timegrad.yaml │ │ ├── timesnet.yaml │ │ ├── trans_maf.yaml │ │ └── transformer.yaml │ └── wiki │ │ ├── csdi.yaml │ │ ├── dlinear.yaml │ │ ├── gru.yaml │ │ ├── gru_maf.yaml │ │ ├── gru_nvp.yaml │ │ ├── patchtst.yaml │ │ ├── timegrad.yaml │ │ ├── timesnet.yaml │ │ ├── trans_maf.yaml │ │ └── transformer.yaml └── tsfm │ ├── chronos.yaml │ ├── forecastpfn.yaml │ ├── lag_llama.yaml │ ├── moirai.yaml │ ├── moirai │ ├── context_5000 │ │ ├── electricity_ltsf.yaml │ │ ├── electricity_nips.yaml │ │ ├── etth1.yaml │ │ ├── etth2.yaml │ │ ├── ettm1.yaml │ │ ├── ettm2.yaml │ │ ├── exchange_rate_nips.yaml │ │ ├── solar_nips.yaml │ │ └── weather_ltsf.yaml │ └── context_96 │ │ ├── electricity_ltsf.yaml │ │ ├── electricity_nips.yaml │ │ ├── etth1.yaml │ │ ├── etth2.yaml │ │ ├── ettm1.yaml │ │ ├── ettm2.yaml │ │ ├── exchange_rate_nips.yaml │ │ ├── solar_nips.yaml │ │ └── weather_ltsf.yaml │ ├── time_moe.yaml │ ├── timer.yaml │ ├── timesfm.yaml │ ├── 
tinytimemixer.yaml │ └── units.yaml ├── datasets └── .gitignore ├── docs ├── benchmark │ ├── README.md │ ├── figs │ │ └── methodology.jpg │ ├── foundation_model │ │ ├── README.md │ │ ├── chronos.md │ │ ├── figs │ │ │ ├── FM_dataset.jpg │ │ │ ├── FM_summary.jpg │ │ │ ├── fm_short_term.jpg │ │ │ ├── fm_var_hor.jpg │ │ │ ├── foundation_model.png │ │ │ ├── tsfm_analysis.jpg │ │ │ └── tsfm_results.jpg │ │ ├── forecastpfn.md │ │ ├── lag-llama.md │ │ ├── moirai.md │ │ ├── timer.md │ │ ├── timesfm.md │ │ ├── ttm.md │ │ └── units.md │ └── supervised_model │ │ ├── README.md │ │ └── figs │ │ ├── ar_vs_nar.jpg │ │ ├── long_bench.jpg │ │ ├── norm.jpg │ │ ├── point_vs_prob.jpg │ │ ├── short_bench.jpg │ │ └── supervised.png ├── documentation │ ├── Gift_eval.md │ └── README.md └── figs │ ├── data_pipeline.png │ ├── probts_framework.png │ └── probts_logo.png ├── exps └── .gitignore ├── notebook └── data_characteristics.ipynb ├── probts ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── memory_callback.py │ └── time_callback.py ├── data │ ├── __init__.py │ ├── data_manager.py │ ├── data_module.py │ ├── data_utils │ │ ├── data_scaler.py │ │ ├── data_utils.py │ │ ├── get_datasets.py │ │ └── time_features.py │ ├── data_wrapper.py │ └── datasets │ │ ├── gift_eval_datasets.py │ │ ├── multi_horizon_datasets.py │ │ └── single_horizon_datasets.py ├── model │ ├── __init__.py │ ├── forecast_module.py │ ├── forecaster │ │ ├── __init__.py │ │ ├── forecaster.py │ │ ├── point_forecaster │ │ │ ├── __init__.py │ │ │ ├── autoformer.py │ │ │ ├── dlinear.py │ │ │ ├── elastst.py │ │ │ ├── forecastpfn.py │ │ │ ├── gru.py │ │ │ ├── itransformer.py │ │ │ ├── linear.py │ │ │ ├── mean.py │ │ │ ├── moderntcn.py │ │ │ ├── naive.py │ │ │ ├── nhits.py │ │ │ ├── nlinear.py │ │ │ ├── patchtst.py │ │ │ ├── time_moe.py │ │ │ ├── timer.py │ │ │ ├── timesfm.py │ │ │ ├── timesnet.py │ │ │ ├── tinytimemixer.py │ │ │ ├── transformer.py │ │ │ ├── tsmixer.py │ │ │ └── units.py │ │ └── prob_forecaster │ │ │ ├── 
__init__.py │ │ │ ├── chronos.py │ │ │ ├── csdi.py │ │ │ ├── gru_maf.py │ │ │ ├── gru_nvp.py │ │ │ ├── lag_llama.py │ │ │ ├── moirai.py │ │ │ ├── timegrad.py │ │ │ ├── trans_maf.py │ │ │ └── tsdiff.py │ └── nn │ │ ├── __init__.py │ │ ├── arch │ │ ├── AutoformerModule │ │ │ ├── AutoCorrelation.py │ │ │ └── Autoformer_EncDec.py │ │ ├── ChronosModule │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── chronos.py │ │ │ ├── chronos_bolt.py │ │ │ ├── loss.py │ │ │ └── utils.py │ │ ├── Conv_Blocks.py │ │ ├── ElasTSTModule │ │ │ ├── ElasTST_backbone.py │ │ │ ├── Layers.py │ │ │ ├── Modules.py │ │ │ ├── SubLayers.py │ │ │ ├── TRoPE.py │ │ │ └── __init__.py │ │ ├── ModernTCN_backbone.py │ │ ├── Moirai_backbone.py │ │ ├── PatchTSTModule │ │ │ ├── PatchTST_backbone.py │ │ │ └── PatchTST_layers.py │ │ ├── RevIN.py │ │ ├── S4 │ │ │ ├── s4.py │ │ │ └── s4_backbones.py │ │ ├── TSMixer_layers.py │ │ ├── TimesFMModule │ │ │ ├── __init__.py │ │ │ ├── patched_decoder.py │ │ │ ├── pytorch_patched_decoder.py │ │ │ ├── timesfm_base.py │ │ │ ├── timesfm_jax.py │ │ │ ├── timesfm_torch.py │ │ │ └── xreg_lib.py │ │ ├── TransformerModule │ │ │ ├── Embed.py │ │ │ ├── SelfAttention_Family.py │ │ │ └── Transformer_EncDec.py │ │ ├── __init__.py │ │ └── decomp.py │ │ └── prob │ │ ├── MAF.py │ │ ├── RealNVP.py │ │ ├── __init__.py │ │ ├── diffusion_layers.py │ │ ├── flow_model.py │ │ └── gaussian_diffusion.py └── utils │ ├── __init__.py │ ├── download_datasets.py │ ├── evaluator.py │ ├── masking.py │ ├── metrics.py │ ├── position_emb.py │ ├── save_utils.py │ └── utils.py ├── pyproject.toml ├── run.py ├── run.sh └── scripts ├── prepare_datasets.sh ├── prepare_tsfm_checkpoints.sh ├── reproduce_ltsf_results.sh ├── reproduce_stsf_results.sh ├── reproduce_tsfm_results.sh ├── run_elastst.sh └── run_varied_hor_training.sh /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/uni2ts"] 2 | path = submodules/uni2ts 3 | url = 
https://github.com/SalesforceAIResearch/uni2ts.git 4 | [submodule "submodules/lag_llama"] 5 | path = submodules/lag_llama 6 | url = https://github.com/time-series-foundation-models/lag-llama.git 7 | [submodule "submodules/timesfm"] 8 | path = submodules/timesfm 9 | url = https://github.com/google-research/timesfm.git 10 | [submodule "submodules/tsfm"] 11 | path = submodules/tsfm 12 | url = https://github.com/ibm-granite/granite-tsfm.git 13 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all files 2 | * 3 | 4 | # Except README.md 5 | !README.md -------------------------------------------------------------------------------- /config/default/autoformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | # num_sanity_val_steps: 0 14 | # gradient_clip_algorithm: 'norm' 15 | model: 16 | forecaster: 17 | class_path: probts.model.forecaster.point_forecaster.Autoformer 18 | init_args: 19 | moving_avg: 25 20 | factor: 1 21 | n_heads: 8 22 | activation: 'gelu' 23 | e_layers: 2 24 | d_layers: 1 25 | output_attention: false 26 | d_ff: 512 27 | f_hidden_size: 512 28 | embed: 'timeF' 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | num_samples: 1 34 | learning_rate: 1e-3 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: solar_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 |
batch_size: 32 44 | test_batch_size: 32 45 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: solar_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 4 44 | test_batch_size: 4 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/default/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | individual: false 17 | kernel_size: 3 18 | use_lags: true 19 | 
use_feat_idx_emb: true 20 | use_time_feat: true 21 | learning_rate: 0.01 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: solar_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 40 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: solar_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/default/gru_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | 
enc_hidden_size: 40 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: false 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: solar_nips 38 | scaler: identity # identity, standard, temporal 39 | split_val: true 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 7 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 40 17 | enc_num_layers: 2 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: solar_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/itransformer.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.iTransformer 16 | init_args: 17 | factor: 1 18 | n_heads: 8 19 | activation: 'gelu' 20 | e_layers: 2 21 | output_attention: false 22 | f_hidden_size: 256 23 | d_ff: 256 24 | label_len: 48 25 | use_lags: false 26 | use_feat_idx_emb: false 27 | use_time_feat: false 28 | feat_idx_emb_dim: 1 29 | num_samples: 1 30 | learning_rate: 1e-4 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: solar_nips 37 | split_val: true 38 | scaler: standard # identity, standard, temporal 39 | batch_size: 32 40 | test_batch_size: 32 41 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/linear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 30 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.LinearForecaster 15 | init_args: 16 | individual: false 17 | use_lags: true 18 | learning_rate: 0.001 19 | quantiles_num: 20 20 | data: 21 | data_manager: 22 | class_path: probts.data.data_manager.DataManager 23 | init_args: 24 | dataset: solar_nips 25 | split_val: true 26 | scaler: standard # identity, standard, temporal 27 | batch_size: 64 28 | test_batch_size: 64 29 |
num_workers: 8 30 | -------------------------------------------------------------------------------- /config/default/mean.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.MeanForecaster 15 | init_args: 16 | mode: global 17 | learning_rate: 0.001 18 | quantiles_num: 20 19 | data: 20 | data_manager: 21 | class_path: probts.data.data_manager.DataManager 22 | init_args: 23 | dataset: solar_nips 24 | split_val: true 25 | scaler: identity # identity, standard, temporal 26 | batch_size: 64 27 | test_batch_size: 64 28 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/moderntcn.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.ModernTCN 15 | init_args: 16 | ffn_ratio: 1 17 | patch_size: 8 18 | patch_stride: 4 19 | num_blocks: [1] 20 | large_size: [51] 21 | dims: [64, 64, 64, 64] 22 | dropout: 0.3 23 | kernel_size: 3 24 | small_size: [5] 25 | use_multi_scale: false 26 | small_kernel_merged: false 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: etth1 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | context_length: 
96 37 | prediction_length: 96 38 | batch_size: 32 39 | test_batch_size: 32 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/naive.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.NaiveForecaster 15 | learning_rate: 0.001 16 | quantiles_num: 10 17 | data: 18 | data_manager: 19 | class_path: probts.data.data_manager.DataManager 20 | init_args: 21 | dataset: solar_nips 22 | split_val: true 23 | scaler: identity # identity, standard, temporal 24 | batch_size: 64 25 | test_batch_size: 64 26 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/nhits.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.NHiTS 16 | init_args: 17 | n_blocks: [1,1,1] 18 | hidden_size: 512 19 | pooling_mode: 'max' 20 | interpolation_mode: 'linear' 21 | activation: 'ReLU' 22 | initialization: 'lecun_normal' 23 | batch_normalization: false 24 | shared_weights: false 25 | naive_level: 26 | dropout: 0 27 | n_layers: 2 28 | use_lags: false 29 | use_feat_idx_emb: true 30 | use_time_feat: true 31 | feat_idx_emb_dim: 1 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | 
data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: solar_nips 39 | split_val: true 40 | scaler: standard # identity, standard, temporal 41 | batch_size: 64 42 | test_batch_size: 64 43 | num_workers: 8 44 | -------------------------------------------------------------------------------- /config/default/nlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.NLinear 15 | init_args: 16 | individual: false 17 | use_lags: false 18 | use_feat_idx_emb: false 19 | use_time_feat: false 20 | learning_rate: 0.01 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: solar_nips 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | batch_size: 64 30 | test_batch_size: 64 31 | num_workers: 8 32 | -------------------------------------------------------------------------------- /config/default/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 3 18 | patch_len: 6 19 | dropout: 0.1 20 | f_hidden_size: 32 21 | n_layers: 3 22 | n_heads: 8 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: 
false 26 | optimizer_config: 27 | class_name: torch.optim.Adam 28 | init_args: 29 | weight_decay: 0 30 | lr_scheduler_config: 31 | class_name: torch.optim.lr_scheduler.OneCycleLR 32 | init_args: 33 | max_lr: 0.0001 34 | steps_per_epoch: 100 35 | pct_start: 0.3 36 | epochs: 50 37 | learning_rate: 0.0001 38 | quantiles_num: 20 39 | data: 40 | data_manager: 41 | class_path: probts.data.data_manager.DataManager 42 | init_args: 43 | dataset: exchange_rate_nips 44 | split_val: true 45 | scaler: standard # identity, standard, temporal 46 | batch_size: 64 47 | test_batch_size: 64 48 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | learning_rate: 0.001 30 | quantiles_num: 20 31 | data: 32 | data_manager: 33 | class_path: probts.data.data_manager.DataManager 34 | init_args: 35 | dataset: solar_nips 36 | scaler: identity # identity, standard, temporal 37 | split_val: true 38 | batch_size: 64 39 | test_batch_size: 64 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # 
lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 32 20 | dropout: 0.1 21 | f_hidden_size: 40 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: solar_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- /config/default/trans_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | enc_hidden_size: 32 17 | enc_num_heads: 8 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 4 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: false 27 | conditional_length: 200 28 | dequantize: true 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: 
probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: solar_nips 42 | scaler: identity # identity, standard, temporal 43 | split_val: true 44 | batch_size: 64 45 | test_batch_size: 64 46 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 16 17 | num_heads: 4 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 21 | dropout: 0.1 22 | activation: gelu 23 | use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: solar_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/default/tsdiff.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 1 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 1 14 | gradient_clip_val: 0.5 15 | model: 16 | forecaster: 17 | class_path: 
probts.model.forecaster.prob_forecaster.TSDiffCond 18 | init_args: 19 | timesteps: 100 20 | hidden_dim: 64 21 | step_emb: 128 22 | num_residual_blocks: 3 23 | dropout: 0.0 24 | mode: diag # diag, nplr 25 | measure: diag # 'diag', 'diag-lin', 'diag-inv', or 'diag-legs' for diag 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: false 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: solar_nips 39 | split_val: true 40 | scaler: temporal # identity, standard, temporal 41 | context_length: 336 42 | batch_size: 32 43 | test_batch_size: 32 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/default/tsmixer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.TSMixer 16 | init_args: 17 | num_blocks: 6 18 | dropout_rate: 0.7 19 | ff_dim: 64 20 | use_lags: false 21 | use_feat_idx_emb: false 22 | use_time_feat: false 23 | feat_idx_emb_dim: 1 24 | learning_rate: 0.0001 25 | quantiles_num: 20 26 | data: 27 | data_manager: 28 | class_path: probts.data.data_manager.DataManager 29 | init_args: 30 | dataset: etth1 31 | split_val: true 32 | scaler: standard # identity, standard, temporal 33 | context_length: 96 34 | prediction_length: 96 35 | batch_size: 64 36 | test_batch_size: 64 37 | num_workers: 8 38 | -------------------------------------------------------------------------------- 
/config/ltsf/electricity_ltsf/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 3 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 64 19 | emb_feature_dim: 8 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 64 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 16 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: electricity_ltsf 41 | scaler: standard # identity, standard, temporal 42 | split_val: true 43 | batch_size: 4 44 | test_batch_size: 8 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/electricity_ltsf/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 200 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 2 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinearEncoder 16 | init_args: 17 | individual: true 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.001 
23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: electricity_ltsf 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 16 34 | test_batch_size: 16 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/electricity_ltsf/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 128 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 64 22 | n_hidden: 2 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: electricity_ltsf 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/electricity_ltsf/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: 
auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 128 21 | n_layers: 3 22 | n_heads: 16 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: false 26 | num_samples: 100 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: electricity_ltsf 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | context_length: 96 37 | prediction_length: 96 38 | batch_size: 8 39 | test_batch_size: 8 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/electricity_ltsf/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 128 23 | enc_num_layers: 3 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: electricity_ltsf 38 | split_val: true 
39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/etth1/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: etth1 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/etth1/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | 
default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: true 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.005 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: etth1 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/etth1/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 4 21 | hidden_size: 64 22 | n_hidden: 3 23 | batch_norm: false 24 | conditional_length: 100 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: etth1 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 64 44 | test_batch_size: 64 45 | num_workers: 8 46 | 
-------------------------------------------------------------------------------- /config/ltsf/etth1/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.3 20 | f_hidden_size: 16 21 | n_layers: 3 22 | n_heads: 4 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: etth1 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 32 38 | test_batch_size: 32 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/etth1/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 128 23 | enc_num_layers: 3 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | 
feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: etth1 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/etth2/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: etth2 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/etth2/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # 
lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.05 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: etth2 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/etth2/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 4 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 128 22 | n_hidden: 3 23 | batch_norm: true 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: etth2 39 | 
path: ./datasets/ # NOTE(review): replaced leaked absolute user home path; sibling gru_nvp configs omit this key — confirm the DataManager default and drop if redundant 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | context_length: 96 43 | prediction_length: 96 44 | batch_size: 16 45 | test_batch_size: 16 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- /config/ltsf/etth2/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.3 20 | f_hidden_size: 16 21 | d_ff: 128 22 | n_layers: 3 23 | n_heads: 4 24 | fc_dropout: 0.2 25 | head_dropout: 0 26 | individual: false 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: etth2 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | context_length: 96 37 | prediction_length: 96 38 | batch_size: 32 39 | test_batch_size: 32 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/etth2/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 |
init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 100 22 | enc_hidden_size: 64 23 | enc_num_layers: 4 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: etth2 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/ettm1/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: ettm1 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | 
batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/ettm1/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: true 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.0001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: ettm1 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/ettm1/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 4 21 | hidden_size: 64 22 | n_hidden: 3 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | 
use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: ettm1 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/ettm1/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 128 21 | n_layers: 3 22 | n_heads: 16 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: ettm1 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 32 38 | test_batch_size: 32 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/ettm1/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | 
log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 128 23 | enc_num_layers: 3 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: ettm1 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/ettm2/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: 
probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: ettm2 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/ettm2/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: ettm2 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/ettm2/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 4 19 | 
enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 128 22 | n_hidden: 3 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: ettm2 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/ettm2/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 128 21 | n_layers: 3 22 | n_heads: 16 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: ettm2 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 32 38 | test_batch_size: 32 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/ettm2/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # 
lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 64 23 | enc_num_layers: 2 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: ettm2 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/exchange_ltsf/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | 
use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: exchange_ltsf 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/exchange_ltsf/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: true 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.0005 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: exchange_ltsf 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/exchange_ltsf/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 
| default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 128 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 128 22 | n_hidden: 3 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: exchange_ltsf 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/exchange_ltsf/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 16 21 | n_layers: 3 22 | n_heads: 4 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: exchange_ltsf 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 32 38 | 
test_batch_size: 32 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/exchange_ltsf/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 64 23 | enc_num_layers: 4 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: exchange_ltsf 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/illness_ltsf/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 
128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: illness_ltsf 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/illness_ltsf/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.01 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: illness_ltsf 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 36 32 | prediction_length: 36 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/illness_ltsf/gru_nvp.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 4 19 | enc_dropout: 0.1 20 | n_blocks: 4 21 | hidden_size: 128 22 | n_hidden: 2 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: illness_ltsf 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 36 42 | prediction_length: 36 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/illness_ltsf/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 2 18 | patch_len: 24 19 | dropout: 0.3 20 | f_hidden_size: 16 21 | n_layers: 3 22 | n_heads: 4 23 | fc_dropout: 0.3 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 
0.0025 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: illness_ltsf 33 | # path: ./datasets/  # set to your local dataset directory if needed (removed hard-coded user-specific absolute path) 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | context_length: 36 37 | prediction_length: 36 38 | batch_size: 32 39 | test_batch_size: 32 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/illness_ltsf/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 64 23 | enc_num_layers: 2 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: illness_ltsf 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 36 41 | prediction_length: 36 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/ltsf/traffic_ltsf/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator:
gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 3 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 64 19 | emb_feature_dim: 8 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 64 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 16 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: traffic_ltsf 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 4 46 | test_batch_size: 4 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/traffic_ltsf/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 4 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.05 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: traffic_ltsf 29 | split_val: 
true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 8 34 | test_batch_size: 8 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/traffic_ltsf/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 128 18 | enc_num_layers: 3 19 | enc_dropout: 0.1 20 | n_blocks: 4 21 | hidden_size: 128 22 | n_hidden: 3 23 | batch_norm: true 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: traffic_ltsf 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/traffic_ltsf/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 300 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 3 13 | 
model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 128 21 | n_layers: 3 22 | n_heads: 16 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: false 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: traffic_ltsf 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 8 38 | test_batch_size: 8 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/traffic_ltsf/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 128 23 | enc_num_layers: 3 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: traffic_ltsf 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | 
-------------------------------------------------------------------------------- /config/ltsf/weather_ltsf/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: weather_ltsf 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | context_length: 96 44 | prediction_length: 96 45 | batch_size: 8 46 | test_batch_size: 8 47 | num_workers: 8 48 | -------------------------------------------------------------------------------- /config/ltsf/weather_ltsf/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | accumulate_grad_batches: 1 12 | default_root_dir: ./results 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | 
kernel_size: 25 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.0001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: weather_ltsf 29 | split_val: true 30 | scaler: standard # identity, standard, temporal 31 | context_length: 96 32 | prediction_length: 96 33 | batch_size: 32 34 | test_batch_size: 32 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/ltsf/weather_ltsf/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 64 18 | enc_num_layers: 4 19 | enc_dropout: 0.1 20 | n_blocks: 4 21 | hidden_size: 128 22 | n_hidden: 3 23 | batch_norm: false 24 | conditional_length: 200 25 | dequantize: false 26 | use_lags: true 27 | use_feat_idx_emb: true 28 | use_time_feat: true 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: weather_ltsf 39 | split_val: true 40 | scaler: identity # identity, standard, temporal 41 | context_length: 96 42 | prediction_length: 96 43 | batch_size: 16 44 | test_batch_size: 16 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/ltsf/weather_ltsf/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # 
lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 8 18 | patch_len: 16 19 | dropout: 0.2 20 | f_hidden_size: 128 21 | n_layers: 3 22 | n_heads: 16 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: false 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: weather_ltsf 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | context_length: 96 36 | prediction_length: 96 37 | batch_size: 32 38 | test_batch_size: 32 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/ltsf/weather_ltsf/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 16 | init_args: 17 | loss_type: l2 18 | diff_steps: 100 19 | beta_end: 0.1 20 | beta_schedule: linear 21 | conditional_length: 200 22 | enc_hidden_size: 64 23 | enc_num_layers: 4 24 | enc_dropout: 0.1 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: 
probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: weather_ltsf 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | context_length: 96 41 | prediction_length: 96 42 | batch_size: 16 43 | test_batch_size: 16 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/m4_daily/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 32 19 | emb_feature_dim: 4 20 | channels: 16 21 | n_layers: 4 22 | num_heads: 4 23 | num_steps: 50 24 | diffusion_embedding_dim: 32 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: m4_daily 41 | context_length_factor: 3 42 | split_val: true 43 | scaler: standard # identity, standard, temporal 44 | batch_size: 1 45 | test_batch_size: 1 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- /config/m4/m4_daily/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 
800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 3 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: m4_daily 29 | context_length_factor: 3 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/m4/m4_daily/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 40 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 100 22 | n_hidden: 2 23 | batch_norm: true 24 | conditional_length: 100 25 | dequantize: false 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: m4_daily 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | 
-------------------------------------------------------------------------------- /config/m4/m4_daily/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 2 18 | patch_len: 6 19 | dropout: 0.3 20 | f_hidden_size: 32 21 | d_ff: 128 22 | n_layers: 3 23 | n_heads: 8 24 | fc_dropout: 0.2 25 | head_dropout: 0 26 | individual: true 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: m4_daily 34 | context_length_factor: 3 35 | split_val: true 36 | scaler: standard # identity, standard, temporal 37 | batch_size: 1 38 | test_batch_size: 128 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/m4/m4_daily/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 17 | init_args: 18 | loss_type: l2 19 | diff_steps: 50 20 | beta_end: 0.1 21 | beta_schedule: linear 22 | conditional_length: 100 23 | enc_hidden_size: 64 24 | enc_num_layers: 4 25 | enc_dropout: 0.1 26 | use_lags: false 27 | use_feat_idx_emb: false 28 
| use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: m4_daily 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/m4_weekly/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 32 19 | emb_feature_dim: 4 20 | channels: 16 21 | n_layers: 4 22 | num_heads: 4 23 | num_steps: 50 24 | diffusion_embedding_dim: 32 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: m4_weekly 41 | context_length_factor: 3 42 | split_val: true 43 | scaler: standard # identity, standard, temporal 44 | batch_size: 1 45 | test_batch_size: 1 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- /config/m4/m4_weekly/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | 
seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 3 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: m4_weekly 29 | context_length_factor: 3 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/m4/m4_weekly/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 40 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 100 22 | n_hidden: 2 23 | batch_norm: true 24 | conditional_length: 100 25 | dequantize: false 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: m4_weekly 39 | context_length_factor: 3 40 | 
split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/m4_weekly/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 3 18 | patch_len: 6 19 | dropout: 0.3 20 | f_hidden_size: 32 21 | d_ff: 128 22 | n_layers: 3 23 | n_heads: 8 24 | fc_dropout: 0.2 25 | head_dropout: 0 26 | individual: true 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: m4_weekly 34 | context_length_factor: 3 35 | split_val: true 36 | scaler: standard # identity, standard, temporal 37 | batch_size: 1 38 | test_batch_size: 128 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/m4/m4_weekly/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 17 | init_args: 18 | loss_type: l2 19 | diff_steps: 50 20 | beta_end: 0.1 21 | beta_schedule: linear 22 | 
conditional_length: 100 23 | enc_hidden_size: 64 24 | enc_num_layers: 4 25 | enc_dropout: 0.1 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: m4_weekly 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/m5/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 32 19 | emb_feature_dim: 4 20 | channels: 16 21 | n_layers: 4 22 | num_heads: 4 23 | num_steps: 50 24 | diffusion_embedding_dim: 32 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: m5 41 | context_length_factor: 3 42 | split_val: true 43 | scaler: standard # identity, standard, temporal 44 | batch_size: 1 45 | test_batch_size: 1 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- 
/config/m4/m5/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.point_forecaster.DLinear 17 | init_args: 18 | individual: false 19 | kernel_size: 3 20 | use_lags: false 21 | use_feat_idx_emb: false 22 | use_time_feat: false 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: m5 30 | context_length_factor: 3 31 | split_val: true 32 | scaler: standard # identity, standard, temporal 33 | batch_size: 1 34 | test_batch_size: 256 35 | num_workers: 8 36 | -------------------------------------------------------------------------------- /config/m4/m5/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 17 | init_args: 18 | enc_hidden_size: 40 19 | enc_num_layers: 2 20 | enc_dropout: 0.1 21 | n_blocks: 2 22 | hidden_size: 100 23 | n_hidden: 2 24 | batch_norm: true 25 | conditional_length: 100 26 | dequantize: false 27 | use_lags: false 28 | use_feat_idx_emb: false 29 | use_time_feat: false 30 | feat_idx_emb_dim: 1 31 | use_scaling: true 32 | num_samples: 100 33 | learning_rate: 
0.001 34 | quantiles_num: 20 35 | data: 36 | data_manager: 37 | class_path: probts.data.data_manager.DataManager 38 | init_args: 39 | dataset: m5 40 | context_length_factor: 3 41 | split_val: true 42 | scaler: identity # identity, standard, temporal 43 | batch_size: 1 44 | test_batch_size: 1 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/m4/m5/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.point_forecaster.PatchTST 17 | init_args: 18 | stride: 2 19 | patch_len: 4 20 | dropout: 0.3 21 | f_hidden_size: 64 22 | d_ff: 128 23 | n_layers: 3 24 | n_heads: 8 25 | fc_dropout: 0.2 26 | head_dropout: 0 27 | individual: true 28 | learning_rate: 0.0001 29 | quantiles_num: 20 30 | data: 31 | data_manager: 32 | class_path: probts.data.data_manager.DataManager 33 | init_args: 34 | dataset: m5 35 | context_length_factor: 3 36 | split_val: true 37 | scaler: standard # identity, standard, temporal 38 | batch_size: 1 39 | test_batch_size: 128 40 | num_workers: 8 -------------------------------------------------------------------------------- /config/m4/m5/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 30 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 
| forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 17 | init_args: 18 | loss_type: l2 19 | diff_steps: 50 20 | beta_end: 0.1 21 | beta_schedule: linear 22 | conditional_length: 100 23 | enc_hidden_size: 64 24 | enc_num_layers: 4 25 | enc_dropout: 0.1 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: m5 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 512 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/tourism_monthly/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 32 19 | emb_feature_dim: 4 20 | channels: 16 21 | n_layers: 4 22 | num_heads: 4 23 | num_steps: 50 24 | diffusion_embedding_dim: 32 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: tourism_monthly 41 | context_length_factor: 3 42 | split_val: true 43 | 
scaler: standard # identity, standard, temporal 44 | batch_size: 1 45 | test_batch_size: 1 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- /config/m4/tourism_monthly/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.DLinear 16 | init_args: 17 | individual: false 18 | kernel_size: 3 19 | use_lags: false 20 | use_feat_idx_emb: false 21 | use_time_feat: false 22 | learning_rate: 0.001 23 | quantiles_num: 20 24 | data: 25 | data_manager: 26 | class_path: probts.data.data_manager.DataManager 27 | init_args: 28 | dataset: tourism_monthly 29 | context_length_factor: 3 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/m4/tourism_monthly/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 16 | init_args: 17 | enc_hidden_size: 40 18 | enc_num_layers: 2 19 | enc_dropout: 0.1 20 | n_blocks: 2 21 | hidden_size: 100 22 | n_hidden: 2 23 | batch_norm: true 24 | conditional_length: 100 25 | dequantize: false 
26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: tourism_monthly 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/m4/tourism_monthly/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 8 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 2 18 | patch_len: 6 19 | dropout: 0.3 20 | f_hidden_size: 64 21 | d_ff: 128 22 | n_layers: 3 23 | n_heads: 8 24 | fc_dropout: 0.2 25 | head_dropout: 0 26 | individual: true 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: tourism_monthly 34 | context_length_factor: 3 35 | split_val: true 36 | scaler: standard # identity, standard, temporal 37 | batch_size: 1 38 | test_batch_size: 128 39 | num_workers: 8 -------------------------------------------------------------------------------- /config/m4/tourism_monthly/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | 
use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 2 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 17 | init_args: 18 | loss_type: l2 19 | diff_steps: 50 20 | beta_end: 0.1 21 | beta_schedule: linear 22 | conditional_length: 100 23 | enc_hidden_size: 64 24 | enc_num_layers: 4 25 | enc_dropout: 0.1 26 | use_lags: false 27 | use_feat_idx_emb: false 28 | use_time_feat: false 29 | feat_idx_emb_dim: 1 30 | use_scaling: true 31 | num_samples: 100 32 | learning_rate: 0.001 33 | quantiles_num: 20 34 | data: 35 | data_manager: 36 | class_path: probts.data.data_manager.DataManager 37 | init_args: 38 | dataset: tourism_monthly 39 | context_length_factor: 3 40 | split_val: true 41 | scaler: identity # identity, standard, temporal 42 | batch_size: 1 43 | test_batch_size: 1 44 | num_workers: 8 45 | -------------------------------------------------------------------------------- /config/pipeline_config.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.3.0dev 2 | seed_everything: true 3 | trainer: 4 | accelerator: auto 5 | strategy: auto 6 | devices: auto 7 | num_nodes: 1 8 | precision: null 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: null 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: null 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: null 25 | log_every_n_steps: null 26 | enable_checkpointing: null 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 1 30 | gradient_clip_val: null 31 | gradient_clip_algorithm: null 32 | deterministic: null 33 | 
benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: null 40 | sync_batchnorm: false 41 | reload_dataloaders_every_n_epochs: 0 42 | default_root_dir: null 43 | model: 44 | forecaster: null 45 | num_samples: 100 46 | learning_rate: 0.001 47 | quantiles_num: 10 48 | load_from_ckpt: null 49 | data: 50 | data_manager: null 51 | batch_size: 64 52 | test_batch_size: 8 53 | num_workers: 8 54 | -------------------------------------------------------------------------------- /config/stsf/electricity/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: electricity_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 4 44 | test_batch_size: 4 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/stsf/electricity/dlinear.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | individual: true 17 | kernel_size: 3 18 | use_lags: false 19 | use_feat_idx_emb: false 20 | use_time_feat: false 21 | learning_rate: 0.01 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: electricity_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/electricity/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 40 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: electricity_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 35 | 
-------------------------------------------------------------------------------- /config/stsf/electricity/gru_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | enc_hidden_size: 40 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: electricity_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/electricity/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 40 17 | enc_num_layers: 2 18 | enc_dropout: 0.1 19 | n_blocks: 3 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | 
dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: electricity_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/electricity/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 2 18 | patch_len: 4 19 | dropout: 0.1 20 | f_hidden_size: 64 21 | n_layers: 4 22 | n_heads: 8 23 | fc_dropout: 0.1 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: electricity_nips 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | batch_size: 64 36 | test_batch_size: 64 37 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/electricity/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | 
log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | num_samples: 100 30 | learning_rate: 0.001 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: electricity_nips 37 | split_val: true 38 | scaler: identity # identity, standard, temporal 39 | batch_size: 64 40 | test_batch_size: 64 41 | num_workers: 8 42 | -------------------------------------------------------------------------------- /config/stsf/electricity/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 64 20 | dropout: 0.1 21 | f_hidden_size: 64 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: electricity_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- 
/config/stsf/electricity/trans_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | enc_hidden_size: 32 17 | enc_num_heads: 8 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 4 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: true 27 | conditional_length: 200 28 | dequantize: false 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: electricity_nips 42 | split_val: true 43 | scaler: identity # identity, standard, temporal 44 | batch_size: 64 45 | test_batch_size: 64 46 | num_workers: 8 47 | -------------------------------------------------------------------------------- /config/stsf/electricity/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 32 17 | num_heads: 8 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 
21 | dropout: 0.1 22 | activation: gelu 23 | use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: electricity_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 39 | -------------------------------------------------------------------------------- /config/stsf/exchange/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: exchange_rate_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 4 44 | test_batch_size: 4 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/stsf/exchange/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # 
lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | individual: false 17 | kernel_size: 3 18 | use_lags: false 19 | use_feat_idx_emb: false 20 | use_time_feat: false 21 | learning_rate: 0.01 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: exchange_rate_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/exchange/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 40 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: exchange_rate_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/stsf/exchange/gru_maf.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | enc_hidden_size: 40 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: false 23 | conditional_length: 200 24 | dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: exchange_rate_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/exchange/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 40 17 | enc_num_layers: 2 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | 
use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: exchange_rate_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/exchange/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 3 18 | patch_len: 6 19 | dropout: 0.1 20 | f_hidden_size: 32 21 | n_layers: 3 22 | n_heads: 8 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: exchange_rate_nips 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | batch_size: 64 36 | test_batch_size: 64 37 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/exchange/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: 
probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | num_samples: 100 30 | learning_rate: 0.001 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: exchange_rate_nips 37 | split_val: true 38 | scaler: identity # identity, standard, temporal 39 | batch_size: 64 40 | test_batch_size: 64 41 | num_workers: 8 42 | -------------------------------------------------------------------------------- /config/stsf/exchange/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 64 20 | dropout: 0.1 21 | f_hidden_size: 64 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.0001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: exchange_rate_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- /config/stsf/exchange/trans_maf.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | enc_hidden_size: 16 17 | enc_num_heads: 8 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 4 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: false 27 | conditional_length: 200 28 | dequantize: false 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: exchange_rate_nips 42 | split_val: true 43 | scaler: identity # identity, standard, temporal 44 | batch_size: 64 45 | test_batch_size: 64 46 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/exchange/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 32 17 | num_heads: 8 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 21 | dropout: 0.1 22 | activation: gelu 23 | 
use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: exchange_rate_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/solar/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 800 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 2 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 8 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 128 19 | emb_feature_dim: 16 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 128 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 64 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: solar_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 4 44 | test_batch_size: 4 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/stsf/solar/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 
5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | individual: false 17 | kernel_size: 3 18 | use_lags: true 19 | use_feat_idx_emb: true 20 | use_time_feat: true 21 | learning_rate: 0.01 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: solar_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/solar/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 40 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: solar_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/stsf/solar/gru_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | 
seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | enc_hidden_size: 40 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: false 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: solar_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/solar/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 40 17 | enc_num_layers: 2 18 | enc_dropout: 0.1 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | 
class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: solar_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/solar/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 3 18 | patch_len: 6 19 | dropout: 0.1 20 | f_hidden_size: 32 21 | n_layers: 3 22 | n_heads: 8 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: true 26 | learning_rate: 0.0001 27 | quantiles_num: 20 28 | data: 29 | data_manager: 30 | class_path: probts.data.data_manager.DataManager 31 | init_args: 32 | dataset: solar_nips 33 | split_val: true 34 | scaler: standard # identity, standard, temporal 35 | batch_size: 64 36 | test_batch_size: 64 37 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/solar/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 
100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | num_samples: 100 30 | learning_rate: 0.001 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: solar_nips 37 | split_val: true 38 | scaler: identity # identity, standard, temporal 39 | batch_size: 64 40 | test_batch_size: 64 41 | num_workers: 8 42 | -------------------------------------------------------------------------------- /config/stsf/solar/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 16 20 | dropout: 0.1 21 | f_hidden_size: 16 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: solar_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- /config/stsf/solar/trans_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 
100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | enc_hidden_size: 32 17 | enc_num_heads: 8 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 4 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: false 27 | conditional_length: 200 28 | dequantize: true 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: solar_nips 42 | split_val: true 43 | scaler: identity # identity, standard, temporal 44 | batch_size: 64 45 | test_batch_size: 64 46 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/solar/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 16 17 | num_heads: 4 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 21 | dropout: 0.1 22 | activation: gelu 23 | use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: solar_nips 34 | split_val: true 35 | 
scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/traffic/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 3 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 4 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 64 19 | emb_feature_dim: 8 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 64 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 16 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: traffic_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 8 44 | test_batch_size: 8 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/stsf/traffic/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | 
individual: false 17 | kernel_size: 3 18 | use_lags: false 19 | use_feat_idx_emb: false 20 | use_time_feat: false 21 | learning_rate: 0.001 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: traffic_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/traffic/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 128 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: traffic_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 32 33 | test_batch_size: 32 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/stsf/traffic/gru_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: 
probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | enc_hidden_size: 128 18 | enc_dropout: 0.3 19 | n_blocks: 3 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: traffic_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 32 41 | test_batch_size: 32 42 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/traffic/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 128 17 | enc_num_layers: 2 18 | enc_dropout: 0.3 19 | n_blocks: 4 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: false 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: traffic_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 32 41 | test_batch_size: 32 42 | num_workers: 8 
-------------------------------------------------------------------------------- /config/stsf/traffic/patchtst.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 1 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 3 18 | patch_len: 6 19 | dropout: 0.1 20 | f_hidden_size: 32 21 | n_layers: 3 22 | n_heads: 8 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: false 26 | num_samples: 100 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: traffic_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/traffic/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | 
num_samples: 100 30 | learning_rate: 0.001 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: traffic_nips 37 | split_val: true 38 | scaler: identity # identity, standard, temporal 39 | batch_size: 32 40 | test_batch_size: 32 41 | num_workers: 8 42 | -------------------------------------------------------------------------------- /config/stsf/traffic/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 16 20 | dropout: 0.1 21 | f_hidden_size: 16 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: traffic_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- /config/stsf/traffic/trans_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | 
enc_hidden_size: 128 17 | enc_num_heads: 4 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 3 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: true 27 | conditional_length: 200 28 | dequantize: false 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: traffic_nips 42 | split_val: true 43 | scaler: identity # identity, standard, temporal 44 | batch_size: 32 45 | test_batch_size: 32 46 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/traffic/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 32 17 | num_heads: 8 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 21 | dropout: 0.1 22 | activation: gelu 23 | use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: traffic_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 32 37 | test_batch_size: 32 38 | num_workers: 8 
-------------------------------------------------------------------------------- /config/stsf/wiki/csdi.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | check_val_every_n_epoch: 3 12 | default_root_dir: ./results 13 | accumulate_grad_batches: 4 14 | model: 15 | forecaster: 16 | class_path: probts.model.forecaster.prob_forecaster.CSDI 17 | init_args: 18 | emb_time_dim: 64 19 | emb_feature_dim: 8 20 | channels: 64 21 | n_layers: 4 22 | num_heads: 8 23 | num_steps: 50 24 | diffusion_embedding_dim: 64 25 | beta_start: 0.001 26 | beta_end: 0.5 27 | sample_size: 16 28 | linear_trans: false 29 | use_lags: false 30 | use_feat_idx_emb: false 31 | use_time_feat: false 32 | feat_idx_emb_dim: 1 33 | num_samples: 100 34 | learning_rate: 0.001 35 | quantiles_num: 20 36 | data: 37 | data_manager: 38 | class_path: probts.data.data_manager.DataManager 39 | init_args: 40 | dataset: wiki2000_nips 41 | split_val: true 42 | scaler: standard # identity, standard, temporal 43 | batch_size: 8 44 | test_batch_size: 8 45 | num_workers: 8 46 | -------------------------------------------------------------------------------- /config/stsf/wiki/dlinear.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.DLinear 15 | init_args: 16 | individual: false 17 | kernel_size: 3 18 | use_lags: false 19 | use_feat_idx_emb: false 20 | use_time_feat: false 21 | 
learning_rate: 0.0001 22 | quantiles_num: 20 23 | data: 24 | data_manager: 25 | class_path: probts.data.data_manager.DataManager 26 | init_args: 27 | dataset: wiki2000_nips 28 | split_val: true 29 | scaler: standard # identity, standard, temporal 30 | batch_size: 32 31 | test_batch_size: 32 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/wiki/gru.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.GRUForecaster 15 | init_args: 16 | f_hidden_size: 40 17 | num_layers: 2 18 | dropout: 0.1 19 | use_lags: true 20 | use_feat_idx_emb: true 21 | use_time_feat: true 22 | feat_idx_emb_dim: 1 23 | learning_rate: 0.001 24 | quantiles_num: 20 25 | data: 26 | data_manager: 27 | class_path: probts.data.data_manager.DataManager 28 | init_args: 29 | dataset: wiki2000_nips 30 | split_val: true 31 | scaler: standard # identity, standard, temporal 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 35 | -------------------------------------------------------------------------------- /config/stsf/wiki/gru_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_MAF 15 | init_args: 16 | enc_num_layers: 2 17 | enc_hidden_size: 40 18 | enc_dropout: 0.1 19 | 
n_blocks: 3 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: wiki2000_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/wiki/gru_nvp.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.GRU_NVP 15 | init_args: 16 | enc_hidden_size: 40 17 | enc_num_layers: 2 18 | enc_dropout: 0.1 19 | n_blocks: 3 20 | hidden_size: 100 21 | n_hidden: 2 22 | batch_norm: true 23 | conditional_length: 200 24 | dequantize: true 25 | use_lags: true 26 | use_feat_idx_emb: true 27 | use_time_feat: true 28 | feat_idx_emb_dim: 1 29 | use_scaling: true 30 | num_samples: 100 31 | learning_rate: 0.001 32 | quantiles_num: 20 33 | data: 34 | data_manager: 35 | class_path: probts.data.data_manager.DataManager 36 | init_args: 37 | dataset: wiki2000_nips 38 | split_val: true 39 | scaler: identity # identity, standard, temporal 40 | batch_size: 64 41 | test_batch_size: 64 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /config/stsf/wiki/patchtst.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 400 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | accumulate_grad_batches: 4 13 | model: 14 | forecaster: 15 | class_path: probts.model.forecaster.point_forecaster.PatchTST 16 | init_args: 17 | stride: 4 18 | patch_len: 8 19 | dropout: 0.1 20 | f_hidden_size: 32 21 | n_layers: 2 22 | n_heads: 8 23 | fc_dropout: 0.2 24 | head_dropout: 0 25 | individual: false 26 | num_samples: 100 27 | learning_rate: 0.0001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: wiki2000_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 16 37 | test_batch_size: 16 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/wiki/timegrad.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.TimeGrad 15 | init_args: 16 | loss_type: l2 17 | diff_steps: 100 18 | beta_end: 0.1 19 | beta_schedule: linear 20 | conditional_length: 100 21 | enc_hidden_size: 128 22 | enc_num_layers: 4 23 | enc_dropout: 0.1 24 | use_lags: true 25 | use_feat_idx_emb: true 26 | use_time_feat: true 27 | feat_idx_emb_dim: 1 28 | use_scaling: true 29 | num_samples: 100 30 | learning_rate: 0.001 31 | quantiles_num: 20 32 | data: 33 | data_manager: 34 | class_path: 
probts.data.data_manager.DataManager 35 | init_args: 36 | dataset: wiki2000_nips 37 | split_val: true 38 | scaler: identity # identity, standard, temporal 39 | batch_size: 64 40 | test_batch_size: 64 41 | num_workers: 8 42 | -------------------------------------------------------------------------------- /config/stsf/wiki/timesnet.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesNet 15 | init_args: 16 | n_layers: 2 17 | num_kernels: 6 18 | top_k: 5 19 | d_ff: 32 20 | dropout: 0.1 21 | f_hidden_size: 32 22 | use_lags: false 23 | use_feat_idx_emb: false 24 | use_time_feat: false 25 | learning_rate: 0.001 26 | quantiles_num: 20 27 | data: 28 | data_manager: 29 | class_path: probts.data.data_manager.DataManager 30 | init_args: 31 | dataset: wiki2000_nips 32 | split_val: true 33 | scaler: standard # identity, standard, temporal 34 | batch_size: 64 35 | test_batch_size: 64 36 | num_workers: 8 37 | -------------------------------------------------------------------------------- /config/stsf/wiki/trans_maf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Trans_MAF 15 | init_args: 16 | enc_hidden_size: 128 17 | enc_num_heads: 4 18 | enc_num_encoder_layers: 2 19 | enc_num_decoder_layers: 2 20 | 
enc_dim_feedforward_scale: 4 21 | enc_dropout: 0.1 22 | enc_activation: gelu 23 | n_blocks: 3 24 | hidden_size: 100 25 | n_hidden: 2 26 | batch_norm: true 27 | conditional_length: 200 28 | dequantize: true 29 | use_lags: true 30 | use_feat_idx_emb: true 31 | use_time_feat: true 32 | feat_idx_emb_dim: 1 33 | use_scaling: true 34 | num_samples: 100 35 | learning_rate: 0.001 36 | quantiles_num: 20 37 | data: 38 | data_manager: 39 | class_path: probts.data.data_manager.DataManager 40 | init_args: 41 | dataset: wiki2000_nips 42 | split_val: true 43 | scaler: identity # identity, standard, temporal 44 | batch_size: 64 45 | test_batch_size: 64 46 | num_workers: 8 -------------------------------------------------------------------------------- /config/stsf/wiki/transformer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 1 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 50 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TransformerForecaster 15 | init_args: 16 | f_hidden_size: 32 17 | num_heads: 8 18 | num_encoder_layers: 3 19 | num_decoder_layers: 3 20 | dim_feedforward_scale: 4 21 | dropout: 0.1 22 | activation: gelu 23 | use_lags: true 24 | use_feat_idx_emb: true 25 | use_time_feat: true 26 | feat_idx_emb_dim: 1 27 | learning_rate: 0.001 28 | quantiles_num: 20 29 | data: 30 | data_manager: 31 | class_path: probts.data.data_manager.DataManager 32 | init_args: 33 | dataset: wiki2000_nips 34 | split_val: true 35 | scaler: standard # identity, standard, temporal 36 | batch_size: 64 37 | test_batch_size: 64 38 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/chronos.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Chronos 15 | init_args: 16 | model_size: base # tiny, mini, small, base, large 17 | num_samples: 100 18 | quantiles_num: 20 19 | data: 20 | data_manager: 21 | class_path: probts.data.data_manager.DataManager 22 | init_args: 23 | dataset: solar_nips 24 | split_val: true 25 | scaler: standard # identity, standard, temporal 26 | batch_size: 16 27 | test_batch_size: 16 28 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/forecastpfn.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.ForecastPFN 15 | init_args: 16 | label_len: 48 17 | ckpt_path: ./checkpoints/ForecastPFN/saved_weights 18 | quantiles_num: 20 19 | data: 20 | data_manager: 21 | class_path: probts.data.data_manager.DataManager 22 | init_args: 23 | dataset: solar_nips 24 | split_val: true 25 | scaler: standard # identity, standard, temporal 26 | timeenc: 2 27 | batch_size: 64 28 | test_batch_size: 64 29 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/lag_llama.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | 
seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.LagLlama 15 | init_args: 16 | use_rope_scaling: true 17 | ckpt_path: ./checkpoints/lag-llama/lag-llama.ckpt 18 | num_samples: 100 19 | quantiles_num: 20 20 | data: 21 | data_manager: 22 | class_path: probts.data.data_manager.DataManager 23 | init_args: 24 | dataset: solar_nips 25 | split_val: true 26 | scaler: identity # identity, standard, temporal 27 | timeenc: 2 28 | batch_size: 1 29 | test_batch_size: 1 30 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: solar_nips 27 | split_val: true 28 | scaler: identity # identity, standard, temporal 29 | auto_search: true 30 | batch_size: 64 31 | test_batch_size: 64 32 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/electricity_ltsf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | 
trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: 128 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: electricity_ltsf 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 5000 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/electricity_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: 64 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: electricity_nips 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: true 30 | context_length: 3800 # maximum history length 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/etth1.yaml: 
-------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: 64 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: etth1 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | context_length: 5000 30 | auto_search: true 31 | batch_size: 64 32 | test_batch_size: 64 33 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/etth2.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: 64 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: etth2 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | context_length: 5000 30 | auto_search: true 31 | batch_size: 64 32 | test_batch_size: 64 33 | num_workers: 8 -------------------------------------------------------------------------------- 
/config/tsfm/moirai/context_5000/ettm1.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: 64 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: ettm1 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | context_length: 5000 30 | auto_search: true 31 | batch_size: 64 32 | test_batch_size: 64 33 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/ettm2.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: 128 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: ettm2 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | context_length: 5000 30 | auto_search: true 31 | batch_size: 64 32 | test_batch_size: 64 33 | num_workers: 8 
-------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/exchange_rate_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: 128 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: exchange_rate_nips 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 5000 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/solar_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: solar_nips 27 | split_val: true 28 | scaler: identity # identity, standard, temporal 29 | var_specific_norm: 
false 30 | context_length: 5000 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_5000/weather_ltsf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: 128 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: weather_ltsf 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: true 30 | context_length: 5000 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/electricity_ltsf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | 
dataset: electricity_ltsf 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 4 33 | test_batch_size: 4 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/electricity_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: 64 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: electricity_nips 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: true 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/etth1.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 
22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: etth1 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/etth2.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: etth2 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/ettm1.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | 
model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: ettm1 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/ettm2.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: ettm2 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/exchange_rate_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: 
probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: exchange_rate_nips 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: true 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/solar_nips.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: S 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: solar_nips 27 | split_val: true 28 | scaler: identity # identity, standard, temporal 29 | var_specific_norm: false 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 1 33 | test_batch_size: 1 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/moirai/context_96/weather_ltsf.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 1 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 
10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.prob_forecaster.Moirai 15 | init_args: 16 | variate_mode: M 17 | patch_size: auto 18 | model_size: base 19 | scaling: true 20 | num_samples: 100 21 | quantiles_num: 20 22 | data: 23 | data_manager: 24 | class_path: probts.data.data_manager.DataManager 25 | init_args: 26 | dataset: weather_ltsf 27 | split_val: true 28 | scaler: standard # identity, standard, temporal 29 | var_specific_norm: true 30 | context_length: 96 31 | auto_search: true 32 | batch_size: 64 33 | test_batch_size: 64 34 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/time_moe.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimeMoE 15 | init_args: 16 | model_size: 200M # select from ['50M', '200M'] 17 | instance_norm: true 18 | quantiles_num: 20 19 | data: 20 | data_manager: 21 | class_path: probts.data.data_manager.DataManager 22 | init_args: 23 | dataset: solar_nips 24 | split_val: true 25 | scaler: identity # identity, standard, temporal 26 | var_specific_norm: true 27 | batch_size: 64 28 | test_batch_size: 64 29 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/timer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | 
log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.Timer 15 | init_args: 16 | label_len: 96 17 | ckpt_path: ./checkpoints/timer/Timer_67M_UTSD_4G.pt 18 | quantiles_num: 20 19 | data: 20 | data_manager: 21 | class_path: probts.data.data_manager.DataManager 22 | init_args: 23 | dataset: solar_nips 24 | split_val: true 25 | scaler: standard # identity, standard, temporal 26 | batch_size: 64 27 | test_batch_size: 64 28 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/timesfm.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.TimesFM 15 | init_args: 16 | model_size: 200m # select from ['200m', '500m'] 17 | quantiles_num: 20 18 | data: 19 | data_manager: 20 | class_path: probts.data.data_manager.DataManager 21 | init_args: 22 | dataset: solar_nips 23 | split_val: true 24 | scaler: identity # identity, standard, temporal 25 | var_specific_norm: true 26 | batch_size: 64 27 | test_batch_size: 64 28 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/tinytimemixer.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: 
probts.model.forecaster.point_forecaster.TinyTimeMixer 15 | quantiles_num: 20 16 | data: 17 | data_manager: 18 | class_path: probts.data.data_manager.DataManager 19 | init_args: 20 | dataset: solar_nips 21 | split_val: true 22 | scaler: standard # identity, standard, temporal 23 | batch_size: 64 24 | test_batch_size: 64 25 | num_workers: 8 -------------------------------------------------------------------------------- /config/tsfm/units.yaml: -------------------------------------------------------------------------------- 1 | # lightning==2.3.0.dev0 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | devices: 1 6 | strategy: auto 7 | max_epochs: 40 8 | use_distributed_sampler: false 9 | limit_train_batches: 100 10 | log_every_n_steps: 1 11 | default_root_dir: ./results 12 | model: 13 | forecaster: 14 | class_path: probts.model.forecaster.point_forecaster.UniTS 15 | init_args: 16 | ckpt_path: ./checkpoints/units/units_x128_pretrain_checkpoint.pth 17 | quantiles_num: 20 18 | data: 19 | data_manager: 20 | class_path: probts.data.data_manager.DataManager 21 | init_args: 22 | dataset: solar_nips 23 | split_val: true 24 | scaler: standard # identity, standard, temporal 25 | # var_norm: true 26 | batch_size: 64 27 | test_batch_size: 64 28 | num_workers: 8 -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /docs/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking :balance_scale: 2 | 3 | Accurate point and distributional forecasts across diverse horizons are crucial for time-series forecasting. However, existing research often focuses on isolated aspects, such as long-term point forecasting or short-term probabilistic estimation. 
This raises a fundamental question: **How do different methodological designs address these diverse forecasting needs?** 4 | 5 | In this repository, we: 6 | 1. **Provide Detailed Reproduction Guides:** Offer comprehensive instructions for replicating supervised models and pre-trained foundation models. 7 | 2. **Evaluate Methods Under a Unified Framework:** Align and assess existing methods across various data scenarios using a consistent benchmarking framework. 8 | 3. **Deliver In-Depth Insights:** Present detailed analyses and insights into the experimental results. 9 | 10 | 11 | ## Benchmarking Scripts 12 | 13 | - [Supervised Forecasting Models](./supervised_model/README.md) 14 | - [Pre-trained Time-Series Foundation Models](./foundation_model/README.md) 15 | 16 | ## Methodology Overview 17 | 18 | ![Methodology](./figs/methodology.jpg) -------------------------------------------------------------------------------- /docs/benchmark/figs/methodology.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/figs/methodology.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/FM_dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/FM_dataset.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/FM_summary.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/FM_summary.jpg -------------------------------------------------------------------------------- 
/docs/benchmark/foundation_model/figs/fm_short_term.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/fm_short_term.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/fm_var_hor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/fm_var_hor.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/foundation_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/foundation_model.png -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/tsfm_analysis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/tsfm_analysis.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/figs/tsfm_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/foundation_model/figs/tsfm_results.jpg -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/forecastpfn.md: -------------------------------------------------------------------------------- 1 | # Running Inference with ForecastPFN 
2 | 3 | [Original Repository](https://github.com/abacusai/ForecastPFN) | [Paper](https://arxiv.org/abs/2311.01933) 4 | 5 | Follow these steps to set up and run inference using ForecastPFN: 6 | 7 | 1. Set up the [environment](../README.md#results-reproduction). 8 | 2. Run the inference script with the following commands: 9 | 10 | ```bash 11 | # ForecastPFN 12 | MODEL='forecastpfn' 13 | for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do 14 | for CTX_LEN in 96; do 15 | for PRED_LEN in 24 48 96 192 336 720; do 16 | python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \ 17 | --data.data_manager.init_args.path ${DATA_DIR} \ 18 | --trainer.default_root_dir ${LOG_DIR} \ 19 | --data.data_manager.init_args.split_val true \ 20 | --data.data_manager.init_args.dataset ${DATASET} \ 21 | --data.data_manager.init_args.context_length ${CTX_LEN} \ 22 | --data.data_manager.init_args.prediction_length ${PRED_LEN} \ 23 | --model.forecaster.init_args.ckpt_path './checkpoints/ForecastPFN/saved_weights' \ 24 | --data.test_batch_size 64 25 | done 26 | done 27 | done 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/lag-llama.md: -------------------------------------------------------------------------------- 1 | # Running Inference with Lag-Llama 2 | 3 | [Original Repository](https://github.com/time-series-foundation-models/lag-llama) | [Paper](https://arxiv.org/abs/2310.08278) 4 | 5 | Follow these steps to set up and run inference using Lag-Llama: 6 | 7 | 1. Set up the [environment and initialize submodules](../README.md#results-reproduction). 8 | 2. 
Run the inference script with the following commands: 9 | 10 | ```bash 11 | # Lag-Llama 12 | MODEL='lag_llama' 13 | for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do 14 | for CTX_LEN in 512; do 15 | for PRED_LEN in 24 48 96 192 336 720; do 16 | python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \ 17 | --data.data_manager.init_args.path ${DATA_DIR} \ 18 | --trainer.default_root_dir ${LOG_DIR} \ 19 | --data.data_manager.init_args.split_val true \ 20 | --data.data_manager.init_args.dataset ${DATASET} \ 21 | --data.data_manager.init_args.context_length ${CTX_LEN} \ 22 | --data.data_manager.init_args.prediction_length ${PRED_LEN} \ 23 | --model.forecaster.init_args.ckpt_path './checkpoints/lag-llama/lag-llama.ckpt' \ 24 | --data.test_batch_size 1 25 | done 26 | done 27 | done 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/ttm.md: -------------------------------------------------------------------------------- 1 | # Running Inference with Tiny Time Mixers 2 | 3 | [Original Repository](https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/models/tinytimemixer) | [Paper](https://arxiv.org/abs/2401.03955) 4 | 5 | Follow these steps to set up and run inference using Tiny Time Mixers: 6 | 7 | 1. Set up the [environment and initialize submodules](../README.md#results-reproduction). 8 | 2. 
Run the inference script with the following commands: 9 | 10 | ```bash 11 | MODEL='tinytimemixer' 12 | for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2' 'weather_ltsf'; do 13 | for CTX_LEN in 5000 96; do 14 | for PRED_LEN in 24 48 96 192 336 720; do 15 | python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \ 16 | --data.data_manager.init_args.path ${DATA_DIR} \ 17 | --trainer.default_root_dir ${LOG_DIR} \ 18 | --data.data_manager.init_args.split_val true \ 19 | --data.data_manager.init_args.dataset ${DATASET} \ 20 | --data.data_manager.init_args.context_length ${CTX_LEN} \ 21 | --data.data_manager.init_args.prediction_length ${PRED_LEN} \ 22 | --data.test_batch_size 1 23 | done 24 | done 25 | done 26 | ``` 27 | -------------------------------------------------------------------------------- /docs/benchmark/foundation_model/units.md: -------------------------------------------------------------------------------- 1 | # Running Inference with UniTS 2 | 3 | [Original Repository](https://github.com/mims-harvard/UniTS) | [Paper](https://arxiv.org/pdf/2403.00131) 4 | 5 | Follow these steps to set up and run inference using UniTS: 6 | 7 | 1. Set up the [environment](../README.md#results-reproduction). 8 | 2. 
Run the inference script with the following commands: 9 | 10 | ```bash 11 | MODEL='units' 12 | for DATASET in 'etth1' 'etth2' 'ettm1' 'ettm2'; do 13 | for CTX_LEN in 96; do 14 | for PRED_LEN in 24 48 96 192 336 720; do 15 | python run.py --config config/tsfm/${MODEL}.yaml --seed_everything 0 \ 16 | --data.data_manager.init_args.path ${DATA_DIR} \ 17 | --trainer.default_root_dir ${LOG_DIR} \ 18 | --data.data_manager.init_args.split_val true \ 19 | --data.data_manager.init_args.dataset ${DATASET} \ 20 | --data.data_manager.init_args.context_length ${CTX_LEN} \ 21 | --data.data_manager.init_args.prediction_length ${PRED_LEN} \ 22 | --model.forecaster.init_args.ckpt_path './checkpoints/units/units_x128_pretrain_checkpoint.pth' \ 23 | --data.test_batch_size 64 24 | done 25 | done 26 | done 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/benchmark/supervised_model/figs/ar_vs_nar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/ar_vs_nar.jpg -------------------------------------------------------------------------------- /docs/benchmark/supervised_model/figs/long_bench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/long_bench.jpg -------------------------------------------------------------------------------- /docs/benchmark/supervised_model/figs/norm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/norm.jpg -------------------------------------------------------------------------------- 
/docs/benchmark/supervised_model/figs/point_vs_prob.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/point_vs_prob.jpg -------------------------------------------------------------------------------- /docs/benchmark/supervised_model/figs/short_bench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/short_bench.jpg -------------------------------------------------------------------------------- /docs/benchmark/supervised_model/figs/supervised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/benchmark/supervised_model/figs/supervised.png -------------------------------------------------------------------------------- /docs/documentation/Gift_eval.md: -------------------------------------------------------------------------------- 1 | 2 | ## How to evaluate the models in ProbTS using the GIFT-EVAL benchmark 3 | 4 | Link to the GIFT-EVAL benchmark: [Github Repo](https://github.com/SalesforceAIResearch/gift-eval) [Paper](https://openreview.net/forum?id=9EBSEkFSje) 5 | 6 | 1. Follow installation instructions in the GIFT-EVAL repository to **download the dataset** from its huggingface dataset repository. 7 | 2. Also, set the environment variable `GIFT_EVAL` to the path where the dataset is downloaded. 8 | ``` bash 9 | echo "GIFT_EVAL=/path/to/gift-eval" >> .env 10 | ``` 11 | 3. 
Quick start example: 12 | ``` bash 13 | python run.py --config config/default/mean.yaml \ 14 | --seed_everything 0 \ 15 | --model.forecaster.init_args.mode batch \ 16 | --data.data_manager.init_args.dataset gift/ett1/H/long \ 17 | --data.data_manager.init_args.path ./datasets \ 18 | --trainer.default_root_dir ./exps 19 | ``` 20 | 21 | > [!NOTE] 22 | > The dataset name for the GIFT-EVAL format should be specified as follows: `"gift/" + "dataset_name (main_name/freq)" + "short/medium/long"`. For example, `gift/ett1/H/long`. More dataset names can be found in the GIFT-EVAL repository (for example [naive.ipynb](https://github.com/SalesforceAIResearch/gift-eval/blob/main/notebooks/naive.ipynb)). 23 | -------------------------------------------------------------------------------- /docs/figs/data_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/figs/data_pipeline.png -------------------------------------------------------------------------------- /docs/figs/probts_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/figs/probts_framework.png -------------------------------------------------------------------------------- /docs/figs/probts_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/docs/figs/probts_logo.png -------------------------------------------------------------------------------- /exps/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /probts/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .data import * 2 | from .model import * 3 | from .utils import * -------------------------------------------------------------------------------- /probts/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .memory_callback import MemoryCallback 2 | from .time_callback import TimeCallback -------------------------------------------------------------------------------- /probts/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_module import * 2 | from .data_manager import * 3 | from .data_utils.time_features import * -------------------------------------------------------------------------------- /probts/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .forecast_module import * -------------------------------------------------------------------------------- /probts/model/forecaster/__init__.py: -------------------------------------------------------------------------------- 1 | from .forecaster import Forecaster 2 | from .point_forecaster import * 3 | from .prob_forecaster import * -------------------------------------------------------------------------------- /probts/model/forecaster/point_forecaster/mean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from probts.model.forecaster import Forecaster 4 | 5 | 6 | class MeanForecaster(Forecaster): 7 | def __init__( 8 | self, 9 | global_mean: torch.Tensor, 10 | mode: str = 'batch', 11 | **kwargs 12 | ): 13 | super().__init__(**kwargs) 14 | self.global_mean = global_mean 15 | self.mode = mode 16 | self.no_training = True 17 | 18 | @property 19 | def name(self): 20 | return self.mode + self.__class__.__name__ 21 | 22 | def forecast(self, batch_data, 
num_samples=None): 23 | B = batch_data.past_target_cdf.shape[0] 24 | if self.mode == 'global': 25 | outputs = self.global_mean.clone() 26 | elif self.mode == 'batch': 27 | outputs = torch.mean(batch_data.past_target_cdf, dim=1) 28 | outputs = torch.mean(outputs, dim=0) 29 | else: 30 | raise ValueError(f"Unsupported mode: {self.mode}") 31 | 32 | outputs = repeat(outputs,'d -> b n l d', b=B, n=1, l=self.prediction_length) 33 | return outputs 34 | -------------------------------------------------------------------------------- /probts/model/forecaster/point_forecaster/naive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from probts.model.forecaster import Forecaster 4 | import sys 5 | 6 | class NaiveForecaster(Forecaster): 7 | def __init__( 8 | self, 9 | **kwargs 10 | ): 11 | super().__init__(**kwargs) 12 | self.no_training = True 13 | 14 | 15 | def forecast(self, batch_data, num_samples=None): 16 | last_value = batch_data.past_target_cdf[:,-1,:] 17 | outputs = repeat(last_value,'b k -> b n l k', n=1, l=self.prediction_length) 18 | return outputs 19 | -------------------------------------------------------------------------------- /probts/model/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/probts/model/nn/__init__.py -------------------------------------------------------------------------------- /probts/model/nn/arch/ChronosModule/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .base import BaseChronosPipeline, ForecastType 5 | from .chronos import ( 6 | ChronosConfig, 7 | ChronosModel, 8 | ChronosPipeline, 9 | ChronosTokenizer, 10 | MeanScaleUniformBins, 11 | ) 12 | 13 | from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline 14 | 15 | __all__ = [ 16 | "BaseChronosPipeline", 17 | "ForecastType", 18 | "ChronosConfig", 19 | "ChronosModel", 20 | "ChronosPipeline", 21 | "ChronosTokenizer", 22 | "MeanScaleUniformBins", 23 | "ChronosBoltConfig", 24 | "ChronosBoltPipeline", 25 | ] 26 | -------------------------------------------------------------------------------- /probts/model/nn/arch/ChronosModule/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | from typing import List 6 | 7 | import torch 8 | 9 | 10 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: 11 | max_len = max(len(c) for c in tensors) 12 | padded = [] 13 | for c in tensors: 14 | assert isinstance(c, torch.Tensor) 15 | assert c.ndim == 1 16 | padding = torch.full( 17 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device 18 | ) 19 | padded.append(torch.concat((padding, c), dim=-1)) 20 | return torch.stack(padded) 21 | -------------------------------------------------------------------------------- /probts/model/nn/arch/ElasTSTModule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/probts/model/nn/arch/ElasTSTModule/__init__.py -------------------------------------------------------------------------------- /probts/model/nn/arch/TimesFMModule/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed 
under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # # limitations under the License. 14 | """TimesFM init file.""" 15 | # print( 16 | # "TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs." 17 | # ) 18 | from probts.model.nn.arch.TimesFMModule.timesfm_base import freq_map, TimesFmCheckpoint, TimesFmHparams, TimesFmBase 19 | 20 | # print("Loaded PyTorch TimesFM.") 21 | from probts.model.nn.arch.TimesFMModule.timesfm_torch import TimesFmTorch as TimesFm 22 | -------------------------------------------------------------------------------- /probts/model/nn/arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/probts/model/nn/arch/__init__.py -------------------------------------------------------------------------------- /probts/model/nn/arch/decomp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class moving_avg(nn.Module): 5 | """ 6 | Moving average block to highlight the trend of time series 7 | """ 8 | def __init__(self, kernel_size, stride): 9 | super(moving_avg, self).__init__() 10 | self.kernel_size = kernel_size 11 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 12 | 13 | def forward(self, x): 14 | # padding on the both ends of time series 15 | front = x[:, 0:1, :].repeat(1, 
(self.kernel_size - 1) // 2, 1) 16 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 17 | x = torch.cat([front, x, end], dim=1) 18 | x = self.avg(x.permute(0, 2, 1)) 19 | x = x.permute(0, 2, 1) 20 | return x 21 | 22 | 23 | class series_decomp(nn.Module): 24 | """ 25 | Series decomposition block 26 | """ 27 | def __init__(self, kernel_size): 28 | super(series_decomp, self).__init__() 29 | self.moving_avg = moving_avg(kernel_size, stride=1) 30 | 31 | def forward(self, x): 32 | moving_mean = self.moving_avg(x) 33 | res = x - moving_mean 34 | return res, moving_mean -------------------------------------------------------------------------------- /probts/model/nn/prob/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ProbTS/66f826cb39cfecb7920f0f953fc8743b6c28229b/probts/model/nn/prob/__init__.py -------------------------------------------------------------------------------- /probts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .evaluator import Evaluator -------------------------------------------------------------------------------- /probts/utils/masking.py: -------------------------------------------------------------------------------- 1 | # Code implementation from https://github.com/thuml/iTransformer 2 | import torch 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = 
_mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /scripts/prepare_datasets.sh: -------------------------------------------------------------------------------- 1 | # Check if gdown is installed 2 | if pip show gdown > /dev/null 2>&1; then 3 | echo "gdown is already installed, skipping installation." 4 | else 5 | echo "gdown is not installed, installing..." 6 | pip install gdown 7 | fi 8 | 9 | python probts/utils/download_datasets.py --data_path $1 -------------------------------------------------------------------------------- /scripts/reproduce_stsf_results.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | DATA_DIR=./datasets 4 | LOG_DIR=./exps 5 | 6 | for DATASET in 'solar' 'electricity' 'exchange' 'traffic' 'wiki' 7 | do 8 | for MODEL in 'dlinear' 'patchtst' 'gru_nvp' 'gru_maf' 'trans_maf' 'timegrad' 'csdi' 'timesnet' 9 | do 10 | python run.py --config config/stsf/${DATASET}/${MODEL}.yaml --seed_everything 0 \ 11 | --data.data_manager.init_args.path ${DATA_DIR} \ 12 | --trainer.default_root_dir ${LOG_DIR} \ 13 | --data.data_manager.init_args.split_val true 14 | done 15 | done 16 | -------------------------------------------------------------------------------- /scripts/run_elastst.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/path/to/datasets 2 | LOG_DIR=/path/to/log_dir 3 | 4 | # for varied-horizon forecasting 5 | 6 | TRAIN_CTX_LEN=96 7 | VAL_CTX_LEN=96 8 | TEST_CTX_LEN=96 9 | 10 | TRAIN_PRED_LEN=720 11 | VAL_PRED_LEN=720 12 | TEST_PRED_LEN=24-48-96-192-336-720 13 | 14 | 15 | DATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 
'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf'] 16 | 17 | MODEL=elastst 18 | 19 | python run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0 \ 20 | --data.data_manager.init_args.path ${DATA_DIR} \ 21 | --trainer.default_root_dir ${LOG_DIR} \ 22 | --data.data_manager.init_args.split_val true \ 23 | --data.data_manager.init_args.dataset ${DATASET} \ 24 | --data.data_manager.init_args.context_length ${TEST_CTX_LEN} \ 25 | --data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \ 26 | --data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \ 27 | --data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \ 28 | --data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \ 29 | --data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \ 30 | --trainer.max_epochs 50 -------------------------------------------------------------------------------- /scripts/run_varied_hor_training.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/path/to/datasets 2 | LOG_DIR=/path/to/log_dir 3 | 4 | # for varied-horizon forecasting 5 | 6 | TRAIN_CTX_LEN=96 7 | VAL_CTX_LEN=96 8 | TEST_CTX_LEN=96 9 | 10 | TRAIN_PRED_LEN=1-720 11 | VAL_PRED_LEN=720 12 | TEST_PRED_LEN=24-48-96-192-336-720 13 | 14 | 15 | DATASET='exchange_ltsf' # select from ['etth1', 'etth2', 'ettm1', 'ettm2', 'traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'weather_ltsf'] 16 | 17 | MODEL=elastst 18 | 19 | python run.py --config config/multi_hor/${MODEL}.yaml --seed_everything 0 \ 20 | --data.data_manager.init_args.path ${DATA_DIR} \ 21 | --trainer.default_root_dir ${LOG_DIR} \ 22 | --data.data_manager.init_args.split_val true \ 23 | --data.data_manager.init_args.dataset ${DATASET} \ 24 | --data.data_manager.init_args.context_length ${TEST_CTX_LEN} \ 25 | --data.data_manager.init_args.prediction_length ${TEST_PRED_LEN} \ 26 | --data.data_manager.init_args.train_pred_len_list ${TRAIN_PRED_LEN} \ 27 | 
--data.data_manager.init_args.train_ctx_len ${TRAIN_CTX_LEN} \ 28 | --data.data_manager.init_args.val_ctx_len ${VAL_CTX_LEN} \ 29 | --data.data_manager.init_args.val_pred_len_list ${VAL_PRED_LEN} \ 30 | --data.data_manager.init_args.continuous_sample true \ 31 | --trainer.max_epochs 50 --------------------------------------------------------------------------------